mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 08:07:32 +00:00
feat: v2alpha user service idp endpoints (#5879)
* feat: v2alpha user service idp endpoints * feat: v2alpha user service intent endpoints * begin idp intents (callback) * some cleanup * runnable idp authentication * cleanup * proto cleanup * retrieve idp info * improve success and failure handling * some unit tests * grpc unit tests * add permission check AddUserIDPLink * feat: v2alpha intent writemodel refactoring * feat: v2alpha intent writemodel refactoring * feat: v2alpha intent writemodel refactoring * provider from write model * fix idp type model and add integration tests * proto cleanup * fix integration test * add missing import * add more integration tests * auth url test * feat: v2alpha intent writemodel refactoring * remove unused functions * check token on RetrieveIdentityProviderInformation * feat: v2alpha intent writemodel refactoring * fix TestServer_RetrieveIdentityProviderInformation * fix test * i18n and linting * feat: v2alpha intent review changes --------- Co-authored-by: Livio Spring <livio.a@gmail.com> Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
@@ -629,7 +629,7 @@ func (s *Server) importData(ctx context.Context, orgs []*admin_pb.DataOrg) (*adm
|
||||
ExternalUserID: userLinks.ProvidedUserId,
|
||||
DisplayName: userLinks.ProvidedUserName,
|
||||
}
|
||||
if err := s.command.AddUserIDPLink(ctx, userLinks.UserId, org.GetOrgId(), externalIDP); err != nil {
|
||||
if _, err := s.command.AddUserIDPLink(ctx, userLinks.UserId, org.GetOrgId(), externalIDP); err != nil {
|
||||
errors = append(errors, &admin_pb.ImportDataError{Type: "user_link", Id: userLinks.UserId + "_" + userLinks.IdpId, Message: err.Error()})
|
||||
if isCtxTimeout(ctx) {
|
||||
return &admin_pb.ImportDataResponse{Errors: errors, Success: success}, count, err
|
||||
|
@@ -241,7 +241,6 @@ func AddHumanUserRequestToAddHuman(req *mgmt_pb.AddHumanUserRequest) *command.Ad
|
||||
PasswordChangeRequired: true,
|
||||
Passwordless: false,
|
||||
Register: false,
|
||||
ExternalIDP: false,
|
||||
}
|
||||
if req.Phone != nil {
|
||||
human.Phone = command.Phone{
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
@@ -18,15 +20,25 @@ type Server struct {
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
userCodeAlg crypto.EncryptionAlgorithm
|
||||
idpAlg crypto.EncryptionAlgorithm
|
||||
idpCallback func(ctx context.Context) string
|
||||
}
|
||||
|
||||
type Config struct{}
|
||||
|
||||
func CreateServer(command *command.Commands, query *query.Queries, userCodeAlg crypto.EncryptionAlgorithm) *Server {
|
||||
func CreateServer(
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
userCodeAlg crypto.EncryptionAlgorithm,
|
||||
idpAlg crypto.EncryptionAlgorithm,
|
||||
idpCallback func(ctx context.Context) string,
|
||||
) *Server {
|
||||
return &Server{
|
||||
command: command,
|
||||
query: query,
|
||||
userCodeAlg: userCodeAlg,
|
||||
idpAlg: idpAlg,
|
||||
idpCallback: idpCallback,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2,15 +2,19 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
@@ -56,6 +60,14 @@ func addUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman,
|
||||
Value: metadataEntry.GetValue(),
|
||||
}
|
||||
}
|
||||
links := make([]*command.AddLink, len(req.GetIdpLinks()))
|
||||
for i, link := range req.GetIdpLinks() {
|
||||
links[i] = &command.AddLink{
|
||||
IDPID: link.GetIdpId(),
|
||||
IDPExternalID: link.GetIdpExternalId(),
|
||||
DisplayName: link.GetDisplayName(),
|
||||
}
|
||||
}
|
||||
return &command.AddHuman{
|
||||
ID: req.GetUserId(),
|
||||
Username: username,
|
||||
@@ -76,9 +88,9 @@ func addUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman,
|
||||
BcryptedPassword: bcryptedPassword,
|
||||
PasswordChangeRequired: passwordChangeRequired,
|
||||
Passwordless: false,
|
||||
ExternalIDP: false,
|
||||
Register: false,
|
||||
Metadata: metadata,
|
||||
Links: links,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -107,3 +119,95 @@ func hashedPasswordToCommand(hashed *user.HashedPassword) (string, error) {
|
||||
}
|
||||
return hashed.GetHash(), nil
|
||||
}
|
||||
|
||||
func (s *Server) AddIDPLink(ctx context.Context, req *user.AddIDPLinkRequest) (_ *user.AddIDPLinkResponse, err error) {
|
||||
orgID := authz.GetCtxData(ctx).OrgID
|
||||
details, err := s.command.AddUserIDPLink(ctx, req.UserId, orgID, &domain.UserIDPLink{
|
||||
IDPConfigID: req.GetIdpLink().GetIdpId(),
|
||||
ExternalUserID: req.GetIdpLink().GetIdpExternalId(),
|
||||
DisplayName: req.GetIdpLink().GetDisplayName(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.AddIDPLinkResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) StartIdentityProviderFlow(ctx context.Context, req *user.StartIdentityProviderFlowRequest) (_ *user.StartIdentityProviderFlowResponse, err error) {
|
||||
id, details, err := s.command.CreateIntent(ctx, req.GetIdpId(), req.GetSuccessUrl(), req.GetFailureUrl(), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authURL, err := s.command.AuthURLFromProvider(ctx, req.GetIdpId(), id, s.idpCallback(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.StartIdentityProviderFlowResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{AuthUrl: authURL},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RetrieveIdentityProviderInformation(ctx context.Context, req *user.RetrieveIdentityProviderInformationRequest) (_ *user.RetrieveIdentityProviderInformationResponse, err error) {
|
||||
intent, err := s.command.GetIntentWriteModel(ctx, req.GetIntentId(), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.checkIntentToken(req.GetToken(), intent.AggregateID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if intent.State != domain.IDPIntentStateSucceeded {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "IDP-Hk38e", "Errors.Intent.NotSucceeded")
|
||||
}
|
||||
return intentToIDPInformationPb(intent, s.idpAlg)
|
||||
}
|
||||
|
||||
func intentToIDPInformationPb(intent *command.IDPIntentWriteModel, alg crypto.EncryptionAlgorithm) (_ *user.RetrieveIdentityProviderInformationResponse, err error) {
|
||||
var idToken *string
|
||||
if intent.IDPIDToken != "" {
|
||||
idToken = &intent.IDPIDToken
|
||||
}
|
||||
var accessToken string
|
||||
if intent.IDPAccessToken != nil {
|
||||
accessToken, err = crypto.DecryptString(intent.IDPAccessToken, alg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &user.RetrieveIdentityProviderInformationResponse{
|
||||
Details: &object_pb.Details{
|
||||
Sequence: intent.ProcessedSequence,
|
||||
ChangeDate: timestamppb.New(intent.ChangeDate),
|
||||
ResourceOwner: intent.ResourceOwner,
|
||||
},
|
||||
IdpInformation: &user.IDPInformation{
|
||||
Access: &user.IDPInformation_Oauth{
|
||||
Oauth: &user.IDPOAuthAccessInformation{
|
||||
AccessToken: accessToken,
|
||||
IdToken: idToken,
|
||||
},
|
||||
},
|
||||
IdpInformation: intent.IDPUser,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkIntentToken(token string, intentID string) error {
|
||||
if token == "" {
|
||||
return errors.ThrowPermissionDenied(nil, "IDP-Sfefs", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
data, err := base64.RawURLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(err, "IDP-Swg31", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
decryptedToken, err := s.idpAlg.Decrypt(data, s.idpAlg.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(err, "IDP-Sf4gt", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
if string(decryptedToken) != intentID {
|
||||
return errors.ThrowPermissionDenied(nil, "IDP-dkje3", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@@ -6,16 +6,24 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/repository/idp"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,7 +47,60 @@ func TestMain(m *testing.M) {
|
||||
}())
|
||||
}
|
||||
|
||||
func createProvider(t *testing.T) string {
|
||||
ctx := authz.WithInstance(context.Background(), Tester.Instance)
|
||||
id, _, err := Tester.Commands.AddOrgGenericOAuthProvider(ctx, Tester.Organisation.ID, command.GenericOAuthProvider{
|
||||
"idp",
|
||||
"clientID",
|
||||
"clientSecret",
|
||||
"https://example.com/oauth/v2/authorize",
|
||||
"https://example.com/oauth/v2/token",
|
||||
"https://api.example.com/user",
|
||||
[]string{"openid", "profile", "email"},
|
||||
"id",
|
||||
idp.Options{
|
||||
IsLinkingAllowed: true,
|
||||
IsCreationAllowed: true,
|
||||
IsAutoCreation: true,
|
||||
IsAutoUpdate: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}
|
||||
|
||||
func createIntent(t *testing.T, idpID string) string {
|
||||
ctx := authz.WithInstance(context.Background(), Tester.Instance)
|
||||
id, _, err := Tester.Commands.CreateIntent(ctx, idpID, "https://example.com/success", "https://example.com/failure", Tester.Organisation.ID)
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}
|
||||
|
||||
func createSuccessfulIntent(t *testing.T, idpID string) (string, string, time.Time, uint64) {
|
||||
ctx := authz.WithInstance(context.Background(), Tester.Instance)
|
||||
intentID := createIntent(t, idpID)
|
||||
writeModel, err := Tester.Commands.GetIntentWriteModel(ctx, intentID, Tester.Organisation.ID)
|
||||
require.NoError(t, err)
|
||||
idpUser := &oauth.UserMapper{
|
||||
RawInfo: map[string]interface{}{
|
||||
"id": "id",
|
||||
},
|
||||
}
|
||||
idpSession := &oauth.Session{
|
||||
Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
},
|
||||
IDToken: "idToken",
|
||||
},
|
||||
}
|
||||
token, err := Tester.Commands.SucceedIDPIntent(ctx, writeModel, idpUser, idpSession, "")
|
||||
require.NoError(t, err)
|
||||
return intentID, token, writeModel.ChangeDate, writeModel.ProcessedSequence
|
||||
}
|
||||
|
||||
func TestServer_AddHumanUser(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.AddHumanUserRequest
|
||||
@@ -287,6 +348,105 @@ func TestServer_AddHumanUser(t *testing.T) {
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing idp",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddHumanUserRequest{
|
||||
Organisation: &object.Organisation{
|
||||
Org: &object.Organisation_OrgId{
|
||||
OrgId: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
Profile: &user.SetHumanProfile{
|
||||
FirstName: "Donald",
|
||||
LastName: "Duck",
|
||||
NickName: gu.Ptr("Dukkie"),
|
||||
DisplayName: gu.Ptr("Donald Duck"),
|
||||
PreferredLanguage: gu.Ptr("en"),
|
||||
Gender: user.Gender_GENDER_DIVERSE.Enum(),
|
||||
},
|
||||
Email: &user.SetHumanEmail{
|
||||
Email: "livio@zitadel.com",
|
||||
Verification: &user.SetHumanEmail_IsVerified{
|
||||
IsVerified: true,
|
||||
},
|
||||
},
|
||||
Metadata: []*user.SetMetadataEntry{
|
||||
{
|
||||
Key: "somekey",
|
||||
Value: []byte("somevalue"),
|
||||
},
|
||||
},
|
||||
PasswordType: &user.AddHumanUserRequest_Password{
|
||||
Password: &user.Password{
|
||||
Password: "DifficultPW666!",
|
||||
ChangeRequired: false,
|
||||
},
|
||||
},
|
||||
IdpLinks: []*user.IDPLink{
|
||||
{
|
||||
IdpId: "idpID",
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "with idp",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddHumanUserRequest{
|
||||
Organisation: &object.Organisation{
|
||||
Org: &object.Organisation_OrgId{
|
||||
OrgId: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
Profile: &user.SetHumanProfile{
|
||||
FirstName: "Donald",
|
||||
LastName: "Duck",
|
||||
NickName: gu.Ptr("Dukkie"),
|
||||
DisplayName: gu.Ptr("Donald Duck"),
|
||||
PreferredLanguage: gu.Ptr("en"),
|
||||
Gender: user.Gender_GENDER_DIVERSE.Enum(),
|
||||
},
|
||||
Email: &user.SetHumanEmail{
|
||||
Email: "livio@zitadel.com",
|
||||
Verification: &user.SetHumanEmail_IsVerified{
|
||||
IsVerified: true,
|
||||
},
|
||||
},
|
||||
Metadata: []*user.SetMetadataEntry{
|
||||
{
|
||||
Key: "somekey",
|
||||
Value: []byte("somevalue"),
|
||||
},
|
||||
},
|
||||
PasswordType: &user.AddHumanUserRequest_Password{
|
||||
Password: &user.Password{
|
||||
Password: "DifficultPW666!",
|
||||
ChangeRequired: false,
|
||||
},
|
||||
},
|
||||
IdpLinks: []*user.IDPLink{
|
||||
{
|
||||
IdpId: idpID,
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &user.AddHumanUserResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -315,3 +475,226 @@ func TestServer_AddHumanUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_AddIDPLink(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.AddIDPLinkRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.AddIDPLinkResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "user does not exist",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddIDPLinkRequest{
|
||||
UserId: "userID",
|
||||
IdpLink: &user.IDPLink{
|
||||
IdpId: idpID,
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "idp does not exist",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddIDPLinkRequest{
|
||||
UserId: Tester.Users[integration.OrgOwner].ID,
|
||||
IdpLink: &user.IDPLink{
|
||||
IdpId: "idpID",
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "add link",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddIDPLinkRequest{
|
||||
UserId: Tester.Users[integration.OrgOwner].ID,
|
||||
IdpLink: &user.IDPLink{
|
||||
IdpId: idpID,
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &user.AddIDPLinkResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.AddIDPLink(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_StartIdentityProviderFlow(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.StartIdentityProviderFlowRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.StartIdentityProviderFlowResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing urls",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.StartIdentityProviderFlowRequest{
|
||||
IdpId: idpID,
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "next step auth url",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.StartIdentityProviderFlowRequest{
|
||||
IdpId: idpID,
|
||||
SuccessUrl: "https://example.com/success",
|
||||
FailureUrl: "https://example.com/failure",
|
||||
},
|
||||
},
|
||||
want: &user.StartIdentityProviderFlowResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{
|
||||
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.StartIdentityProviderFlow(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if nextStep := tt.want.GetNextStep(); nextStep != nil {
|
||||
if !strings.HasPrefix(got.GetAuthUrl(), tt.want.GetAuthUrl()) {
|
||||
assert.Failf(t, "auth url does not match", "expected: %s, but got: %s", tt.want.GetAuthUrl(), got.GetAuthUrl())
|
||||
}
|
||||
}
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RetrieveIdentityProviderInformation(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
intentID := createIntent(t, idpID)
|
||||
successfulID, token, changeDate, sequence := createSuccessfulIntent(t, idpID)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.RetrieveIdentityProviderInformationRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.RetrieveIdentityProviderInformationResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "failed intent",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.RetrieveIdentityProviderInformationRequest{
|
||||
IntentId: intentID,
|
||||
Token: "",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong token",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.RetrieveIdentityProviderInformationRequest{
|
||||
IntentId: successfulID,
|
||||
Token: "wrong token",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve successful intent",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.RetrieveIdentityProviderInformationRequest{
|
||||
IntentId: successfulID,
|
||||
Token: token,
|
||||
},
|
||||
},
|
||||
want: &user.RetrieveIdentityProviderInformationResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.New(changeDate),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
Sequence: sequence,
|
||||
},
|
||||
IdpInformation: &user.IDPInformation{
|
||||
Access: &user.IDPInformation_Oauth{
|
||||
Oauth: &user.IDPOAuthAccessInformation{
|
||||
AccessToken: "accessToken",
|
||||
IdToken: gu.Ptr("idToken"),
|
||||
},
|
||||
},
|
||||
IdpInformation: []byte(`{"RawInfo":{"id":"id"}}`),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.RetrieveIdentityProviderInformation(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, tt.want.GetDetails(), got.GetDetails())
|
||||
require.Equal(t, tt.want.GetIdpInformation(), got.GetIdpInformation())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -3,11 +3,21 @@ package user
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
@@ -78,3 +88,118 @@ func Test_hashedPasswordToCommand(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_intentToIDPInformationPb(t *testing.T) {
|
||||
decryption := func(err error) crypto.EncryptionAlgorithm {
|
||||
mCrypto := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
|
||||
mCrypto.EXPECT().Algorithm().Return("enc")
|
||||
mCrypto.EXPECT().DecryptionKeyIDs().Return([]string{"id"})
|
||||
mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(code []byte, keyID string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(code), nil
|
||||
})
|
||||
return mCrypto
|
||||
}
|
||||
|
||||
type args struct {
|
||||
intent *command.IDPIntentWriteModel
|
||||
alg crypto.EncryptionAlgorithm
|
||||
}
|
||||
type res struct {
|
||||
resp *user.RetrieveIdentityProviderInformationResponse
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"decryption invalid key id error",
|
||||
args{
|
||||
intent: &command.IDPIntentWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: "intentID",
|
||||
ProcessedSequence: 123,
|
||||
ResourceOwner: "ro",
|
||||
InstanceID: "instanceID",
|
||||
ChangeDate: time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local),
|
||||
},
|
||||
IDPID: "idpID",
|
||||
IDPUser: []byte(`{"id": "id"}`),
|
||||
IDPAccessToken: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("accessToken"),
|
||||
},
|
||||
IDPIDToken: "idToken",
|
||||
UserID: "userID",
|
||||
State: domain.IDPIntentStateSucceeded,
|
||||
},
|
||||
alg: decryption(caos_errs.ThrowInternal(nil, "id", "invalid key id")),
|
||||
},
|
||||
res{
|
||||
resp: nil,
|
||||
err: caos_errs.ThrowInternal(nil, "id", "invalid key id"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"successful",
|
||||
args{
|
||||
intent: &command.IDPIntentWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: "intentID",
|
||||
ProcessedSequence: 123,
|
||||
ResourceOwner: "ro",
|
||||
InstanceID: "instanceID",
|
||||
ChangeDate: time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local),
|
||||
},
|
||||
IDPID: "idpID",
|
||||
IDPUser: []byte(`{"id": "id"}`),
|
||||
IDPAccessToken: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("accessToken"),
|
||||
},
|
||||
IDPIDToken: "idToken",
|
||||
UserID: "userID",
|
||||
State: domain.IDPIntentStateSucceeded,
|
||||
},
|
||||
alg: decryption(nil),
|
||||
},
|
||||
res{
|
||||
resp: &user.RetrieveIdentityProviderInformationResponse{
|
||||
Details: &object_pb.Details{
|
||||
Sequence: 123,
|
||||
ChangeDate: timestamppb.New(time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local)),
|
||||
ResourceOwner: "ro",
|
||||
},
|
||||
IdpInformation: &user.IDPInformation{
|
||||
Access: &user.IDPInformation_Oauth{
|
||||
Oauth: &user.IDPOAuthAccessInformation{
|
||||
AccessToken: "accessToken",
|
||||
IdToken: gu.Ptr("idToken"),
|
||||
}},
|
||||
IdpInformation: []byte(`{"id": "id"}`),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := intentToIDPInformationPb(tt.args.intent, tt.args.alg)
|
||||
require.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.resp, got)
|
||||
if tt.res.resp != nil {
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
246
internal/api/idp/idp.go
Normal file
246
internal/api/idp/idp.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
z_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/form"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/github"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/gitlab"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/google"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
|
||||
const (
|
||||
HandlerPrefix = "/idps"
|
||||
callbackPath = "/callback"
|
||||
|
||||
paramIntentID = "id"
|
||||
paramToken = "token"
|
||||
paramUserID = "user"
|
||||
paramError = "error"
|
||||
paramErrorDescription = "error_description"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
commands *command.Commands
|
||||
queries *query.Queries
|
||||
parser *form.Parser
|
||||
encryptionAlgorithm crypto.EncryptionAlgorithm
|
||||
callbackURL func(ctx context.Context) string
|
||||
}
|
||||
|
||||
type externalIDPCallbackData struct {
|
||||
State string `schema:"state"`
|
||||
Code string `schema:"code"`
|
||||
Error string `schema:"error"`
|
||||
ErrorDescription string `schema:"error_description"`
|
||||
}
|
||||
|
||||
// CallbackURL generates the instance specific URL to the IDP callback handler
|
||||
func CallbackURL(externalSecure bool) func(ctx context.Context) string {
|
||||
return func(ctx context.Context) string {
|
||||
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + HandlerPrefix + callbackPath
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(
|
||||
commands *command.Commands,
|
||||
queries *query.Queries,
|
||||
encryptionAlgorithm crypto.EncryptionAlgorithm,
|
||||
externalSecure bool,
|
||||
instanceInterceptor func(next http.Handler) http.Handler,
|
||||
) http.Handler {
|
||||
h := &Handler{
|
||||
commands: commands,
|
||||
queries: queries,
|
||||
parser: form.NewParser(),
|
||||
encryptionAlgorithm: encryptionAlgorithm,
|
||||
callbackURL: CallbackURL(externalSecure),
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.Use(instanceInterceptor)
|
||||
router.HandleFunc(callbackPath, h.handleCallback)
|
||||
return router
|
||||
}
|
||||
|
||||
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := h.parseCallbackRequest(r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
intent := h.getActiveIntent(w, r, data.State)
|
||||
if intent == nil {
|
||||
// if we didn't get an active intent the error was already handled (either redirected or display directly)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
// the provider might have returned an error
|
||||
if data.Error != "" {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, reason(data.Error, data.ErrorDescription))
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
redirectToFailureURL(w, r, intent, data.Error, data.ErrorDescription)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.commands.GetProvider(ctx, intent.IDPID, h.callbackURL(ctx))
|
||||
if err != nil {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
redirectToFailureURLErr(w, r, intent, err)
|
||||
return
|
||||
}
|
||||
|
||||
idpUser, idpSession, err := h.fetchIDPUser(ctx, provider, data.Code)
|
||||
if err != nil {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
redirectToFailureURLErr(w, r, intent, err)
|
||||
return
|
||||
}
|
||||
userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
|
||||
|
||||
token, err := h.commands.SucceedIDPIntent(ctx, intent, idpUser, idpSession, userID)
|
||||
if err != nil {
|
||||
redirectToFailureURLErr(w, r, intent, z_errs.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
|
||||
return
|
||||
}
|
||||
redirectToSuccessURL(w, r, intent, token, userID)
|
||||
}
|
||||
|
||||
func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackData, error) {
|
||||
data := new(externalIDPCallbackData)
|
||||
err := h.parser.Parse(r, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if data.State == "" {
|
||||
return nil, z_errs.ThrowInvalidArgument(nil, "IDP-Hk38e", "Errors.Intent.StateMissing")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getActiveIntent(w http.ResponseWriter, r *http.Request, state string) *command.IDPIntentWriteModel {
|
||||
intent, err := h.commands.GetIntentWriteModel(r.Context(), state, "")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
if intent.State == domain.IDPIntentStateUnspecified {
|
||||
http.Error(w, reason("IDP-Hk38e", "Errors.Intent.NotStarted"), http.StatusBadRequest)
|
||||
return nil
|
||||
}
|
||||
if intent.State != domain.IDPIntentStateStarted {
|
||||
redirectToFailureURL(w, r, intent, "IDP-Sfrgs", "Errors.Intent.NotStarted")
|
||||
return nil
|
||||
}
|
||||
return intent
|
||||
}
|
||||
|
||||
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
|
||||
queries := intent.SuccessURL.Query()
|
||||
queries.Set(paramIntentID, intent.AggregateID)
|
||||
queries.Set(paramToken, token)
|
||||
if userID != "" {
|
||||
queries.Set(paramUserID, userID)
|
||||
}
|
||||
intent.SuccessURL.RawQuery = queries.Encode()
|
||||
http.Redirect(w, r, intent.SuccessURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func redirectToFailureURLErr(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err error) {
|
||||
msg := err.Error()
|
||||
var description string
|
||||
zErr := new(z_errs.CaosError)
|
||||
if errors.As(err, &zErr) {
|
||||
msg = zErr.GetID()
|
||||
description = zErr.GetMessage() // TODO: i18n?
|
||||
}
|
||||
redirectToFailureURL(w, r, i, msg, description)
|
||||
}
|
||||
|
||||
func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err, description string) {
|
||||
queries := i.FailureURL.Query()
|
||||
queries.Set(paramIntentID, i.AggregateID)
|
||||
queries.Set(paramError, err)
|
||||
queries.Set(paramErrorDescription, description)
|
||||
i.FailureURL.RawQuery = queries.Encode()
|
||||
http.Redirect(w, r, i.FailureURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provider, code string) (user idp.User, idpTokens idp.Session, err error) {
|
||||
var session idp.Session
|
||||
switch provider := identityProvider.(type) {
|
||||
case *oauth.Provider:
|
||||
session = &oauth.Session{Provider: provider, Code: code}
|
||||
case *openid.Provider:
|
||||
session = &openid.Session{Provider: provider, Code: code}
|
||||
case *azuread.Provider:
|
||||
session = &oauth.Session{Provider: provider.Provider, Code: code}
|
||||
case *github.Provider:
|
||||
session = &oauth.Session{Provider: provider.Provider, Code: code}
|
||||
case *gitlab.Provider:
|
||||
session = &openid.Session{Provider: provider.Provider, Code: code}
|
||||
case *google.Provider:
|
||||
session = &openid.Session{Provider: provider.Provider, Code: code}
|
||||
case *jwt.Provider, *ldap.Provider:
|
||||
return nil, nil, z_errs.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
|
||||
default:
|
||||
return nil, nil, z_errs.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
|
||||
}
|
||||
|
||||
user, err = session.FetchUser(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, session, nil
|
||||
}
|
||||
|
||||
func (h *Handler) checkExternalUser(ctx context.Context, idpID, externalUserID string) (userID string, err error) {
|
||||
idQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
externalIDQuery, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
queries := []query.SearchQuery{
|
||||
idQuery, externalIDQuery,
|
||||
}
|
||||
links, err := h.queries.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: queries}, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(links.Links) != 1 {
|
||||
return "", nil
|
||||
}
|
||||
return links.Links[0].UserID, nil
|
||||
}
|
||||
|
||||
func reason(err, description string) string {
|
||||
if description == "" {
|
||||
return err
|
||||
}
|
||||
return err + ": " + description
|
||||
}
|
220
internal/api/idp/idp_test.go
Normal file
220
internal/api/idp/idp_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
z_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/form"
|
||||
)
|
||||
|
||||
func Test_redirectToSuccessURL(t *testing.T) {
|
||||
type args struct {
|
||||
id string
|
||||
userID string
|
||||
token string
|
||||
failureURL string
|
||||
successURL string
|
||||
}
|
||||
type res struct {
|
||||
want string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"redirect",
|
||||
args{
|
||||
id: "id",
|
||||
token: "token",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
},
|
||||
res{
|
||||
"https://example.com/success?id=id&token=token",
|
||||
},
|
||||
},
|
||||
{
|
||||
"redirect with userID",
|
||||
args{
|
||||
id: "id",
|
||||
userID: "user",
|
||||
token: "token",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
},
|
||||
res{
|
||||
"https://example.com/success?id=id&token=token&user=user",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
|
||||
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
|
||||
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
|
||||
|
||||
redirectToSuccessURL(resp, req, wm, tt.args.token, tt.args.userID)
|
||||
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redirectToFailureURL(t *testing.T) {
|
||||
type args struct {
|
||||
id string
|
||||
failureURL string
|
||||
successURL string
|
||||
err string
|
||||
desc string
|
||||
}
|
||||
type res struct {
|
||||
want string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"redirect",
|
||||
args{
|
||||
id: "id",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
},
|
||||
res{
|
||||
"https://example.com/failure?error=&error_description=&id=id",
|
||||
},
|
||||
},
|
||||
{
|
||||
"redirect with error",
|
||||
args{
|
||||
id: "id",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
err: "test",
|
||||
desc: "testdesc",
|
||||
},
|
||||
res{
|
||||
"https://example.com/failure?error=test&error_description=testdesc&id=id",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
|
||||
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
|
||||
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
|
||||
|
||||
redirectToFailureURL(resp, req, wm, tt.args.err, tt.args.desc)
|
||||
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redirectToFailureURLErr(t *testing.T) {
|
||||
type args struct {
|
||||
id string
|
||||
failureURL string
|
||||
successURL string
|
||||
err error
|
||||
}
|
||||
type res struct {
|
||||
want string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"redirect with error",
|
||||
args{
|
||||
id: "id",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
err: z_errors.ThrowError(nil, "test", "testdesc"),
|
||||
},
|
||||
res{
|
||||
"https://example.com/failure?error=test&error_description=testdesc&id=id",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
|
||||
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
|
||||
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
|
||||
|
||||
redirectToFailureURLErr(resp, req, wm, tt.args.err)
|
||||
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseCallbackRequest(t *testing.T) {
|
||||
type args struct {
|
||||
url string
|
||||
}
|
||||
type res struct {
|
||||
want *externalIDPCallbackData
|
||||
err bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"no state",
|
||||
args{
|
||||
url: "https://example.com?state=&code=code&error=error&error_description=desc",
|
||||
},
|
||||
res{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"parse",
|
||||
args{
|
||||
url: "https://example.com?state=state&code=code&error=error&error_description=desc",
|
||||
},
|
||||
res{
|
||||
want: &externalIDPCallbackData{
|
||||
State: "state",
|
||||
Code: "code",
|
||||
Error: "error",
|
||||
ErrorDescription: "desc",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", tt.args.url, nil)
|
||||
handler := Handler{parser: form.NewParser()}
|
||||
|
||||
data, err := handler.parseCallbackRequest(req)
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.res.want, data)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user