zitadel/internal/eventstore/repository/mock/repository.mock.impl.go
Livio Spring 911200aa9b
feat(api): allow Device Authorization Grant using custom login UI (#9387)
# Which Problems Are Solved

The OAuth2 Device Authorization Grant could not yet been handled through
the new login UI, resp. using the session API.
This PR adds the ability for the login UI to get the required
information to display the user and handle their decision (approve with
authorization or deny) using the OIDC Service API.

# How the Problems Are Solved

- Added a `GetDeviceAuthorizationRequest` endpoint, which allows getting
the `id`, `client_id`, `scope`, `app_name` and `project_name` of the
device authorization request
- Added a `AuthorizeOrDenyDeviceAuthorization` endpoint, which allows to
approve/authorize with the session information or deny the request. The
identification of the request is done by the `device_authorization_id` /
`id` returned in the previous request.
- To prevent leaking the `device_code` to the UI, but still having an
easy reference, it's encrypted and returned as `id`, resp. decrypted
when used.
- Fixed returned error types for device token responses on token
endpoint:
- Explicitly return `access_denied` (without internal error) when user
denied the request
  - Default to `invalid_grant` instead of `access_denied`
- Explicitly check on initial state when approving the reqeust
- Properly handle done case (also relates to initial check) 
- Documented the flow and handling in custom UIs (according to OIDC /
SAML)

# Additional Changes

- fixed some typos and punctuation in the corresponding OIDC / SAML
guides.
- added some missing translations for auth and saml request

# Additional Context

- closes #6239

---------

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
2025-02-25 07:33:13 +01:00

235 lines
7.7 KiB
Go

package mock
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
type MockRepository struct {
*MockPusher
*MockQuerier
}
func NewRepo(t *testing.T) *MockRepository {
controller := gomock.NewController(t)
return &MockRepository{
MockPusher: NewMockPusher(controller),
MockQuerier: NewMockQuerier(controller),
}
}
func (m *MockRepository) ExpectFilterNoEventsNoError() *MockRepository {
m.MockQuerier.ctrl.T.Helper()
m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
return m
}
func (m *MockRepository) ExpectFilterEvents(events ...eventstore.Event) *MockRepository {
m.MockQuerier.ctrl.T.Helper()
m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, _ *eventstore.SearchQueryBuilder, reduce eventstore.Reducer) error {
for _, event := range events {
if err := reduce(event); err != nil {
return err
}
}
return nil
},
)
return m
}
func (m *MockRepository) ExpectFilterEventsError(err error) *MockRepository {
m.MockQuerier.ctrl.T.Helper()
m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).Return(err)
return m
}
func (m *MockRepository) ExpectInstanceIDs(hasFilters []*repository.Filter, instanceIDs ...string) *MockRepository {
m.MockQuerier.ctrl.T.Helper()
matcher := gomock.Any()
if len(hasFilters) > 0 {
matcher = &filterQueryMatcher{SubQueries: [][]*repository.Filter{hasFilters}}
}
m.MockQuerier.EXPECT().InstanceIDs(gomock.Any(), matcher).Return(instanceIDs, nil)
return m
}
func (m *MockRepository) ExpectInstanceIDsError(err error) *MockRepository {
m.MockQuerier.ctrl.T.Helper()
m.MockQuerier.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return(nil, err)
return m
}
// ExpectPush checks if the expectedCommands are send to the Push method.
// The call will sleep at least the amount of passed duration.
func (m *MockRepository) ExpectPush(expectedCommands []eventstore.Command, sleep time.Duration) *MockRepository {
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
m.MockPusher.ctrl.T.Helper()
time.Sleep(sleep)
if len(expectedCommands) != len(commands) {
return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
}
for i, expectedCommand := range expectedCommands {
if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Aggregate(), commands[i].Aggregate()) {
m.MockPusher.ctrl.T.Errorf("invalid command.Aggregate [%d]: expected: %#v got: %#v", i, expectedCommand.Aggregate(), commands[i].Aggregate())
}
if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Creator(), commands[i].Creator()) {
m.MockPusher.ctrl.T.Errorf("invalid command.Creator [%d]: expected: %#v got: %#v", i, expectedCommand.Creator(), commands[i].Creator())
}
if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Type(), commands[i].Type()) {
m.MockPusher.ctrl.T.Errorf("invalid command.Type [%d]: expected: %#v got: %#v", i, expectedCommand.Type(), commands[i].Type())
}
if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Revision(), commands[i].Revision()) {
m.MockPusher.ctrl.T.Errorf("invalid command.Revision [%d]: expected: %#v got: %#v", i, expectedCommand.Revision(), commands[i].Revision())
}
var expectedPayload []byte
expectedPayload, ok := expectedCommand.Payload().([]byte)
if !ok {
expectedPayload, _ = json.Marshal(expectedCommand.Payload())
}
if string(expectedPayload) == "" {
expectedPayload = []byte("null")
}
gotPayload, _ := json.Marshal(commands[i].Payload())
if !assert.Equal(m.MockPusher.ctrl.T, expectedPayload, gotPayload) {
m.MockPusher.ctrl.T.Errorf("invalid command.Payload [%d]: expected: %#v got: %#v", i, expectedCommand.Payload(), commands[i].Payload())
}
if !assert.ElementsMatch(m.MockPusher.ctrl.T, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints()) {
m.MockPusher.ctrl.T.Errorf("invalid command.UniqueConstraints [%d]: expected: %#v got: %#v", i, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints())
}
}
events := make([]eventstore.Event, len(commands))
for i, command := range commands {
events[i] = &mockEvent{
Command: command,
}
}
return events, nil
},
)
return m
}
func (m *MockRepository) ExpectPushFailed(err error, expectedCommands []eventstore.Command) *MockRepository {
m.MockPusher.ctrl.T.Helper()
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
if len(expectedCommands) != len(commands) {
return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
}
for i, expectedCommand := range expectedCommands {
assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Aggregate(), commands[i].Aggregate())
assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Creator(), commands[i].Creator())
assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Type(), commands[i].Type())
assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Revision(), commands[i].Revision())
var expectedPayload []byte
expectedPayload, ok := expectedCommand.Payload().([]byte)
if !ok {
expectedPayload, _ = json.Marshal(expectedCommand.Payload())
}
if string(expectedPayload) == "" {
expectedPayload = []byte("null")
}
gotPayload, _ := json.Marshal(commands[i].Payload())
assert.Equal(m.MockPusher.ctrl.T, expectedPayload, gotPayload)
assert.ElementsMatch(m.MockPusher.ctrl.T, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints())
}
return nil, err
},
)
return m
}
type mockEvent struct {
eventstore.Command
sequence uint64
createdAt time.Time
}
// DataAsBytes implements eventstore.Event
func (e *mockEvent) DataAsBytes() []byte {
if e.Payload() == nil {
return nil
}
payload, err := json.Marshal(e.Payload())
if err != nil {
panic(err)
}
return payload
}
func (e *mockEvent) Unmarshal(ptr any) error {
if e.Payload() == nil {
return nil
}
payload, err := json.Marshal(e.Payload())
if err != nil {
return err
}
return json.Unmarshal(payload, ptr)
}
func (e *mockEvent) Sequence() uint64 {
return e.sequence
}
func (e *mockEvent) Position() float64 {
return 0
}
func (e *mockEvent) CreatedAt() time.Time {
return e.createdAt
}
func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command) *MockRepository {
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
assert.Len(m.MockPusher.ctrl.T, commands, len(expectedCommands))
events := make([]eventstore.Event, len(commands))
for i, command := range commands {
events[i] = &mockEvent{
Command: command,
}
}
return events, nil
},
)
return m
}
func (m *MockRepository) ExpectRandomPushFailed(err error, expectedEvents []eventstore.Command) *MockRepository {
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.ContextQueryExecuter, events ...eventstore.Command) ([]eventstore.Event, error) {
assert.Len(m.MockPusher.ctrl.T, events, len(expectedEvents))
return nil, err
},
)
return m
}