feat(queries): login policy idp links (#2767)

* fix(idp): set type in projection

* correct table

* user idp links

* refactor: user idp link query

* add not null constraint

* refactor: idp user links

* rename file

* fix(idp): correct resource owner

* refactor: rename test

* fix(query): implement idp login policy links

* unify naming of idp links

* test prepare

* fix(api): convert idp type

* rename migration
This commit is contained in:
Silvan
2021-12-08 14:49:19 +01:00
committed by GitHub
parent 7bf7379a05
commit c9face4ea4
15 changed files with 336 additions and 89 deletions

View File

@@ -0,0 +1,135 @@
package query
import (
"context"
"database/sql"
sq "github.com/Masterminds/squirrel"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query/projection"
)
type IDPLoginPolicyLink struct {
IDPID string
IDPName string
IDPType domain.IDPConfigType
}
type IDPLoginPolicyLinks struct {
SearchResponse
Links []*IDPLoginPolicyLink
}
type IDPLoginPolicyLinksSearchQuery struct {
SearchRequest
Queries []SearchQuery
}
func (q *IDPLoginPolicyLinksSearchQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
query = q.SearchRequest.toQuery(query)
for _, q := range q.Queries {
query = q.toQuery(query)
}
return query
}
var (
idpLoginPolicyLinkTable = table{
name: projection.IDPLoginPolicyLinkTable,
}
IDPLoginPolicyLinkIDPIDCol = Column{
name: projection.IDPLoginPolicyLinkIDPIDCol,
table: idpLoginPolicyLinkTable,
}
IDPLoginPolicyLinkCreationDateCol = Column{
name: projection.IDPLoginPolicyLinkCreationDateCol,
table: idpLoginPolicyLinkTable,
}
IDPLoginPolicyLinkChangeDateCol = Column{
name: projection.IDPLoginPolicyLinkChangeDateCol,
table: idpLoginPolicyLinkTable,
}
IDPLoginPolicyLinkSequenceCol = Column{
name: projection.IDPLoginPolicyLinkSequenceCol,
table: idpLoginPolicyLinkTable,
}
IDPLoginPolicyLinkResourceOwnerCol = Column{
name: projection.IDPLoginPolicyLinkResourceOwnerCol,
table: idpLoginPolicyLinkTable,
}
IDPLoginPolicyLinkProviderTypeCol = Column{
name: projection.IDPLoginPolicyLinkProviderTypeCol,
table: idpLoginPolicyLinkTable,
}
)
func (q *Queries) IDPLoginPolicyLinks(ctx context.Context, resourceOwner string, queries *IDPLoginPolicyLinksSearchQuery) (idps *IDPLoginPolicyLinks, err error) {
query, scan := prepareIDPLoginPolicyLinksQuery()
stmt, args, err := queries.toQuery(query).Where(
sq.Eq{IDPLoginPolicyLinkResourceOwnerCol.identifier(): resourceOwner},
).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-FDbKW", "Errors.Query.InvalidRequest")
}
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-ZkKUc", "Errors.Internal")
}
idps, err = scan(rows)
if err != nil {
return nil, err
}
idps.LatestSequence, err = q.latestSequence(ctx, idpLoginPolicyLinkTable)
return idps, err
}
func prepareIDPLoginPolicyLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) {
return sq.Select(
IDPLoginPolicyLinkIDPIDCol.identifier(),
IDPNameCol.identifier(),
IDPTypeCol.identifier(),
countColumn.identifier()).
From(idpLoginPolicyLinkTable.identifier()).
LeftJoin(join(IDPIDCol, IDPLoginPolicyLinkIDPIDCol)).PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*IDPLoginPolicyLinks, error) {
links := make([]*IDPLoginPolicyLink, 0)
var count uint64
for rows.Next() {
var (
idpName = sql.NullString{}
idpType = sql.NullInt16{}
link = new(IDPLoginPolicyLink)
)
err := rows.Scan(
&link.IDPID,
&idpName,
&idpType,
&count,
)
if err != nil {
return nil, err
}
link.IDPName = idpName.String
//IDPType 0 is oidc so we have to set unspecified manually
if idpType.Valid {
link.IDPType = domain.IDPConfigType(idpType.Int16)
} else {
link.IDPType = domain.IDPConfigTypeUnspecified
}
links = append(links, link)
}
if err := rows.Close(); err != nil {
return nil, errors.ThrowInternal(err, "QUERY-vOLFG", "Errors.Query.CloseRows")
}
return &IDPLoginPolicyLinks{
Links: links,
SearchResponse: SearchResponse{
Count: count,
},
}, nil
}
}

View File

@@ -0,0 +1,121 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"github.com/caos/zitadel/internal/domain"
)
var (
loginPolicyIDPLinksQuery = regexp.QuoteMeta(`SELECT zitadel.projections.idp_login_policy_links.idp_id,` +
` zitadel.projections.idps.name,` +
` zitadel.projections.idps.type,` +
` COUNT(*) OVER ()` +
` FROM zitadel.projections.idp_login_policy_links` +
` LEFT JOIN zitadel.projections.idps ON zitadel.projections.idp_login_policy_links.idp_id = zitadel.projections.idps.id`)
loginPolicyIDPLinksCols = []string{
"idp_id",
"name",
"type",
"count",
}
)
func Test_IDPLoginPolicyLinkPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareIDPsQuery found",
prepare: prepareIDPLoginPolicyLinksQuery,
want: want{
sqlExpectations: mockQueries(
loginPolicyIDPLinksQuery,
loginPolicyIDPLinksCols,
[][]driver.Value{
{
"idp-id",
"idp-name",
domain.IDPConfigTypeJWT,
},
},
),
},
object: &IDPLoginPolicyLinks{
SearchResponse: SearchResponse{
Count: 1,
},
Links: []*IDPLoginPolicyLink{
{
IDPID: "idp-id",
IDPName: "idp-name",
IDPType: domain.IDPConfigTypeJWT,
},
},
},
},
{
name: "prepareIDPsQuery no idp",
prepare: prepareIDPLoginPolicyLinksQuery,
want: want{
sqlExpectations: mockQueries(
loginPolicyIDPLinksQuery,
loginPolicyIDPLinksCols,
[][]driver.Value{
{
"idp-id",
nil,
nil,
},
},
),
},
object: &IDPLoginPolicyLinks{
SearchResponse: SearchResponse{
Count: 1,
},
Links: []*IDPLoginPolicyLink{
{
IDPID: "idp-id",
IDPName: "",
IDPType: domain.IDPConfigTypeUnspecified,
},
},
},
},
{
name: "prepareIDPsQuery sql err",
prepare: prepareIDPLoginPolicyLinksQuery,
want: want{
sqlExpectations: mockQueryErr(
loginPolicyIDPLinksQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
})
}
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/caos/zitadel/internal/query/projection"
)
type UserIDPLink struct {
type IDPUserLink struct {
IDPID string
UserID string
IDPName string
@@ -19,17 +19,17 @@ type UserIDPLink struct {
IDPType domain.IDPConfigType
}
type UserIDPLinks struct {
type IDPUserLinks struct {
SearchResponse
Links []*UserIDPLink
Links []*IDPUserLink
}
type UserIDPLinksSearchQuery struct {
type IDPUserLinksSearchQuery struct {
SearchRequest
Queries []SearchQuery
}
func (q *UserIDPLinksSearchQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
func (q *IDPUserLinksSearchQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
query = q.SearchRequest.toQuery(query)
for _, q := range q.Queries {
query = q.toQuery(query)
@@ -75,8 +75,8 @@ var (
}
)
func (q *Queries) UserIDPLinks(ctx context.Context, queries *UserIDPLinksSearchQuery) (idps *UserIDPLinks, err error) {
query, scan := prepareUserIDPLinksQuery()
func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery) (idps *IDPUserLinks, err error) {
query, scan := prepareIDPUserLinksQuery()
stmt, args, err := queries.toQuery(query).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-4zzFK", "Errors.Query.InvalidRequest")
@@ -94,15 +94,15 @@ func (q *Queries) UserIDPLinks(ctx context.Context, queries *UserIDPLinksSearchQ
return idps, err
}
func NewUserIDPLinksUserIDSearchQuery(value string) (SearchQuery, error) {
func NewIDPUserLinksUserIDSearchQuery(value string) (SearchQuery, error) {
return NewTextQuery(IDPUserLinkUserIDCol, value, TextEquals)
}
func NewUserIDPLinksResourceOwnerSearchQuery(value string) (SearchQuery, error) {
func NewIDPUserLinksResourceOwnerSearchQuery(value string) (SearchQuery, error) {
return NewTextQuery(IDPUserLinkResourceOwnerCol, value, TextEquals)
}
func prepareUserIDPLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserIDPLinks, error)) {
func prepareIDPUserLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPUserLinks, error)) {
return sq.Select(
IDPUserLinkIDPIDCol.identifier(),
IDPUserLinkUserIDCol.identifier(),
@@ -113,14 +113,14 @@ func prepareUserIDPLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserIDPLink
countColumn.identifier()).
From(idpUserLinkTable.identifier()).
LeftJoin(join(IDPIDCol, IDPUserLinkIDPIDCol)).PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*UserIDPLinks, error) {
idps := make([]*UserIDPLink, 0)
func(rows *sql.Rows) (*IDPUserLinks, error) {
idps := make([]*IDPUserLink, 0)
var count uint64
for rows.Next() {
var (
idpName = sql.NullString{}
idpType = sql.NullInt16{}
idp = new(UserIDPLink)
idp = new(IDPUserLink)
)
err := rows.Scan(
&idp.IDPID,
@@ -148,7 +148,7 @@ func prepareUserIDPLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserIDPLink
return nil, errors.ThrowInternal(err, "QUERY-nwx6U", "Errors.Query.CloseRows")
}
return &UserIDPLinks{
return &IDPUserLinks{
Links: idps,
SearchResponse: SearchResponse{
Count: count,

View File

@@ -12,7 +12,7 @@ import (
)
var (
userIDPLinksQuery = regexp.QuoteMeta(`SELECT zitadel.projections.idp_user_links.idp_id,` +
idpUserLinksQuery = regexp.QuoteMeta(`SELECT zitadel.projections.idp_user_links.idp_id,` +
` zitadel.projections.idp_user_links.user_id,` +
` zitadel.projections.idps.name,` +
` zitadel.projections.idp_user_links.external_user_id,` +
@@ -21,7 +21,7 @@ var (
` COUNT(*) OVER ()` +
` FROM zitadel.projections.idp_user_links` +
` LEFT JOIN zitadel.projections.idps ON zitadel.projections.idp_user_links.idp_id = zitadel.projections.idps.id`)
userIDPLinksCols = []string{
idpUserLinksCols = []string{
"idp_id",
"user_id",
"name",
@@ -32,7 +32,7 @@ var (
}
)
func Test_UserIDPLinkPrepares(t *testing.T) {
func Test_IDPUserLinkPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
@@ -45,11 +45,11 @@ func Test_UserIDPLinkPrepares(t *testing.T) {
}{
{
name: "prepareIDPsQuery found",
prepare: prepareUserIDPLinksQuery,
prepare: prepareIDPUserLinksQuery,
want: want{
sqlExpectations: mockQueries(
userIDPLinksQuery,
userIDPLinksCols,
idpUserLinksQuery,
idpUserLinksCols,
[][]driver.Value{
{
"idp-id",
@@ -62,11 +62,11 @@ func Test_UserIDPLinkPrepares(t *testing.T) {
},
),
},
object: &UserIDPLinks{
object: &IDPUserLinks{
SearchResponse: SearchResponse{
Count: 1,
},
Links: []*UserIDPLink{
Links: []*IDPUserLink{
{
IDPID: "idp-id",
UserID: "user-id",
@@ -80,11 +80,11 @@ func Test_UserIDPLinkPrepares(t *testing.T) {
},
{
name: "prepareIDPsQuery no idp",
prepare: prepareUserIDPLinksQuery,
prepare: prepareIDPUserLinksQuery,
want: want{
sqlExpectations: mockQueries(
userIDPLinksQuery,
userIDPLinksCols,
idpUserLinksQuery,
idpUserLinksCols,
[][]driver.Value{
{
"idp-id",
@@ -97,11 +97,11 @@ func Test_UserIDPLinkPrepares(t *testing.T) {
},
),
},
object: &UserIDPLinks{
object: &IDPUserLinks{
SearchResponse: SearchResponse{
Count: 1,
},
Links: []*UserIDPLink{
Links: []*IDPUserLink{
{
IDPID: "idp-id",
UserID: "user-id",
@@ -115,10 +115,10 @@ func Test_UserIDPLinkPrepares(t *testing.T) {
},
{
name: "prepareIDPsQuery sql err",
prepare: prepareUserIDPLinksQuery,
prepare: prepareIDPUserLinksQuery,
want: want{
sqlExpectations: mockQueryErr(
userIDPLinksQuery,
idpUserLinksQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {