diff --git a/go.mod b/go.mod index 1c45ab9104..a7f44ab4f3 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,7 @@ require ( github.com/mitchellh/copystructure v1.2.0 // indirect github.com/muesli/gamut v0.2.0 github.com/nicksnyder/go-i18n/v2 v2.1.2 + github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 // indirect github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.3.0 github.com/rakyll/statik v0.1.7 diff --git a/go.sum b/go.sum index 21c66b16b6..a042afe9df 100644 --- a/go.sum +++ b/go.sum @@ -779,6 +779,8 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS github.com/nicksnyder/go-i18n/v2 v2.1.2 h1:QHYxcUJnGHBaq7XbvgunmZ2Pn0focXFqTD61CkH146c= github.com/nicksnyder/go-i18n/v2 v2.1.2/go.mod h1:d++QJC9ZVf7pa48qrsRWhMJ5pSHIPmS3OLqK1niyLxs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 h1:NHrXEjTNQY7P0Zfx1aMrNhpgxHmow66XQtm0aQLY0AE= +github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249/go.mod h1:mpRZBD8SJ55OIICQ3iWH0Yz3cjzA61JdqMLoWXeB2+8= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= diff --git a/internal/query/prepare_test.go b/internal/query/prepare_test.go new file mode 100644 index 0000000000..fa97424a49 --- /dev/null +++ b/internal/query/prepare_test.go @@ -0,0 +1,324 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "log" + "reflect" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + sq "github.com/Masterminds/squirrel" + "github.com/nsf/jsondiff" +) + +//assertPrepare checks if the prepare func executes the correct sql query and returns the correct object +//prepareFunc must be of type +// func() (sq.SelectBuilder, func(*sql.Rows) (*struct, error)) +// or +// func() (sq.SelectBuilder, func(*sql.Row) (*struct, error)) +//expectedObject represents the return value of scan +//sqlExpectation represents the query executed on the database +func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExpectation sqlExpectation, isErr checkErr) bool { + t.Helper() + + client, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to build mock client: %v", err) + } + + mock = sqlExpectation(mock) + + builder, scan, err := execPrepare(prepareFunc) + if err != nil { + t.Error(err) + return false + } + errCheck := func(err error) (error, bool) { + if isErr == nil { + if err == nil { + return nil, true + } else { + return fmt.Errorf("no error expected got: %w", err), false + } + } + return isErr(err) + } + object, ok := execScan(client, builder, scan, errCheck) + if !ok { + t.Error(object) + return false + } + + if !reflect.DeepEqual(object, expectedObject) { + prettyPrintDiff(t, expectedObject, object) + return false + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sql expectations not met: %v", err) + return false + } + + return true +} + +type checkErr func(error) (err error, ok bool) + +type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock + +func mockQuery(stmt string, cols []string, row []driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + q := m.ExpectQuery(stmt) + result := sqlmock.NewRows(cols) + result.AddRow(row...) + q.WillReturnRows(result) + return m + } +} + +func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + q := m.ExpectQuery(stmt).WithArgs(args...) + result := sqlmock.NewRows(cols) + count := uint64(len(rows)) + for _, row := range rows { + row = append(row, count) + result.AddRow(row...) + } + q.WillReturnRows(result) + q.RowsWillBeClosed() + return m + } +} + +func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + q := m.ExpectQuery(stmt).WithArgs(args...) + q.WillReturnError(err) + return m + } +} + +var ( + rowType = reflect.TypeOf(&sql.Row{}) + rowsType = reflect.TypeOf(&sql.Rows{}) + selectBuilderType = reflect.TypeOf(sq.SelectBuilder{}) +) + +func execScan(client *sql.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (interface{}, bool) { + scanType := reflect.TypeOf(scan) + err := validateScan(scanType) + if err != nil { + return err, false + } + + stmt, args, err := builder.ToSql() + if err != nil { + return fmt.Errorf("unexpeted error from sql builder: %w", err), false + } + + //resultSet represents *sql.Row or *sql.Rows, + // depending on whats assignable to the scan function + var resultSet interface{} + + //execute sql stmt + // if scan(*sql.Rows)... + if scanType.In(0).AssignableTo(rowsType) { + resultSet, err = client.Query(stmt, args...) + if err != nil { + return errCheck(err) + } + + // if scan(*sql.Row)... + } else if scanType.In(0).AssignableTo(rowType) { + row := client.QueryRow(stmt, args...) + if row.Err() != nil { + return errCheck(row.Err()) + } + resultSet = row + } else { + return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false + } + + // res contains object and error + res := reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(resultSet)}) + + //check for error + if res[1].Interface() != nil { + if err, ok := errCheck(res[1].Interface().(error)); !ok { + return fmt.Errorf("scan failed: %w", err), false + } + } + + return res[0].Interface(), true +} + +func validateScan(scanType reflect.Type) error { + if scanType.Kind() != reflect.Func { + return errors.New("scan is not a function") + } + if scanType.NumIn() != 1 { + return fmt.Errorf("scan: invalid number of inputs: want: 1 got %d", scanType.NumIn()) + } + if scanType.NumOut() != 2 { + return fmt.Errorf("scan: invalid number of outputs: want: 2 got %d", scanType.NumOut()) + } + return nil +} + +func execPrepare(prepare interface{}) (builder sq.SelectBuilder, scan interface{}, err error) { + prepareVal := reflect.ValueOf(prepare) + if err := validatePrepare(prepareVal.Type()); err != nil { + return sq.SelectBuilder{}, nil, err + } + res := prepareVal.Call([]reflect.Value{}) + + return res[0].Interface().(sq.SelectBuilder), res[1].Interface(), nil +} + +func validatePrepare(prepareType reflect.Type) error { + if prepareType.Kind() != reflect.Func { + return errors.New("prepare is not a function") + } + if prepareType.NumIn() != 0 { + return fmt.Errorf("prepare: invalid number of inputs: want: 0 got %d", prepareType.NumIn()) + } + if prepareType.NumOut() != 2 { + return fmt.Errorf("prepare: invalid number of outputs: want: 2 got %d", prepareType.NumOut()) + } + if prepareType.Out(0) != selectBuilderType { + return fmt.Errorf("prepare: first return value must be: %s got %s", selectBuilderType, prepareType.Out(0)) + } + if prepareType.Out(1).Kind() != reflect.Func { + return fmt.Errorf("prepare: second return value must be: %s got %s", reflect.Func, prepareType.Out(1)) + } + return nil +} + +func TestValidateScan(t *testing.T) { + tests := []struct { + name string + t reflect.Type + expectErr bool + }{ + { + name: "not a func", + t: reflect.TypeOf(&struct{}{}), + expectErr: true, + }, + { + name: "wong input count", + t: reflect.TypeOf(func() (*struct{}, error) { + log.Fatal("should not be executed") + return nil, nil + }), + expectErr: true, + }, + { + name: "wrong output count", + t: reflect.TypeOf(func(interface{}) error { + log.Fatal("should not be executed") + return nil + }), + expectErr: true, + }, + { + name: "correct", + t: reflect.TypeOf(func(interface{}) (*struct{}, error) { + log.Fatal("should not be executed") + return nil, nil + }), + expectErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateScan(tt.t) + if (err != nil) != tt.expectErr { + t.Errorf("unexpected err: %v", err) + } + }) + } +} + +func TestValidatePrepare(t *testing.T) { + tests := []struct { + name string + t reflect.Type + expectErr bool + }{ + { + name: "not a func", + t: reflect.TypeOf(&struct{}{}), + expectErr: true, + }, + { + name: "wong input count", + t: reflect.TypeOf(func(int) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) { + log.Fatal("should not be executed") + return sq.SelectBuilder{}, nil + }), + expectErr: true, + }, + { + name: "wrong output count", + t: reflect.TypeOf(func() sq.SelectBuilder { + log.Fatal("should not be executed") + return sq.SelectBuilder{} + }), + expectErr: true, + }, + { + name: "first output type wrong", + t: reflect.TypeOf(func() (*struct{}, func(*sql.Rows) (interface{}, error)) { + log.Fatal("should not be executed") + return nil, nil + }), + expectErr: true, + }, + { + name: "second output type wrong", + t: reflect.TypeOf(func() (sq.SelectBuilder, *struct{}) { + log.Fatal("should not be executed") + return sq.SelectBuilder{}, nil + }), + expectErr: true, + }, + { + name: "correct", + t: reflect.TypeOf(func() (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) { + log.Fatal("should not be executed") + return sq.SelectBuilder{}, nil + }), + expectErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePrepare(tt.t) + if (err != nil) != tt.expectErr { + t.Errorf("unexpected err: %v", err) + } + }) + } +} + +func prettyPrintDiff(t *testing.T, expected, gotten interface{}) { + t.Helper() + + expectedMarshalled, _ := json.Marshal(expected) + objectMarshalled, _ := json.Marshal(gotten) + _, diff := jsondiff.Compare( + expectedMarshalled, + objectMarshalled, + &jsondiff.Options{ + SkipMatches: true, + Indent: " ", + ChangedSeparator: " is expected, got ", + }) + t.Errorf("unexpected object: want %T, got %T, difference:\n%s", expected, gotten, diff) +} diff --git a/internal/query/project.go b/internal/query/project.go index b6d37a3054..449707b249 100644 --- a/internal/query/project.go +++ b/internal/query/project.go @@ -155,7 +155,7 @@ func (r *ProjectSearchQueries) AppendMyResourceOwnerQuery(orgID string) error { return nil } -func (r ProjectSearchQueries) AppendPermissionQueries(permissions []string) error { +func (r *ProjectSearchQueries) AppendPermissionQueries(permissions []string) error { if !authz.HasGlobalPermission(permissions) { ids := authz.GetAllPermissionCtxIDs(permissions) query, err := NewProjectIDSearchQuery(ids) diff --git a/internal/query/project_test.go b/internal/query/project_test.go new file mode 100644 index 0000000000..00a86148a6 --- /dev/null +++ b/internal/query/project_test.go @@ -0,0 +1,372 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "testing" + "time" + + "github.com/caos/zitadel/internal/domain" + errs "github.com/caos/zitadel/internal/errors" +) + +var ( + now = time.Now() +) + +func Test_prepareProjectsQuery(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + prepare interface{} + want want + object interface{} + }{ + { + name: "prepareProjectsQuery no result", + prepare: prepareProjectsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.projects.id,`+ + ` zitadel.projections.projects.creation_date,`+ + ` zitadel.projections.projects.change_date,`+ + ` zitadel.projections.projects.resource_owner,`+ + ` zitadel.projections.projects.state,`+ + ` zitadel.projections.projects.sequence,`+ + ` zitadel.projections.projects.name,`+ + ` zitadel.projections.projects.project_role_assertion,`+ + ` zitadel.projections.projects.project_role_check,`+ + ` zitadel.projections.projects.has_project_check,`+ + ` zitadel.projections.projects.private_labeling_setting,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.projects`), + nil, + nil, + ), + }, + object: &Projects{Projects: []*Project{}}, + }, + { + name: "prepareProjectsQuery one result", + prepare: prepareProjectsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.projects.id,`+ + ` zitadel.projections.projects.creation_date,`+ + ` zitadel.projections.projects.change_date,`+ + ` zitadel.projections.projects.resource_owner,`+ + ` zitadel.projections.projects.state,`+ + ` zitadel.projections.projects.sequence,`+ + ` zitadel.projections.projects.name,`+ + ` zitadel.projections.projects.project_role_assertion,`+ + ` zitadel.projections.projects.project_role_check,`+ + ` zitadel.projections.projects.has_project_check,`+ + ` zitadel.projections.projects.private_labeling_setting,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.projects`), + []string{ + "id", + "creation_date", + "change_date", + "resource_owner", + "state", + "sequence", + "name", + "project_role_assertion", + "project_role_check", + "has_project_check", + "private_labeling_setting", + "count", + }, + [][]driver.Value{ + { + "id", + now, + now, + "ro", + domain.ProjectStateActive, + uint64(20211108), + "project-name", + true, + true, + true, + domain.PrivateLabelingSettingEnforceProjectResourceOwnerPolicy, + }, + }, + ), + }, + object: &Projects{ + SearchResponse: SearchResponse{ + Count: 1, + }, + Projects: []*Project{ + { + ID: "id", + CreationDate: now, + ChangeDate: now, + ResourceOwner: "ro", + State: domain.ProjectStateActive, + Sequence: 20211108, + Name: "project-name", + ProjectRoleAssertion: true, + ProjectRoleCheck: true, + HasProjectCheck: true, + PrivateLabelingSetting: domain.PrivateLabelingSettingEnforceProjectResourceOwnerPolicy, + }, + }, + }, + }, + { + name: "prepareProjectsQuery multiple result", + prepare: prepareProjectsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.projects.id,`+ + ` zitadel.projections.projects.creation_date,`+ + ` zitadel.projections.projects.change_date,`+ + ` zitadel.projections.projects.resource_owner,`+ + ` zitadel.projections.projects.state,`+ + ` zitadel.projections.projects.sequence,`+ + ` zitadel.projections.projects.name,`+ + ` zitadel.projections.projects.project_role_assertion,`+ + ` zitadel.projections.projects.project_role_check,`+ + ` zitadel.projections.projects.has_project_check,`+ + ` zitadel.projections.projects.private_labeling_setting,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.projects`), + []string{ + "id", + "creation_date", + "change_date", + "resource_owner", + "state", + "sequence", + "name", + "project_role_assertion", + "project_role_check", + "has_project_check", + "private_labeling_setting", + "count", + }, + [][]driver.Value{ + { + "id-1", + now, + now, + "ro", + domain.ProjectStateActive, + uint64(20211108), + "project-name-1", + true, + true, + true, + domain.PrivateLabelingSettingEnforceProjectResourceOwnerPolicy, + }, + { + "id-2", + now, + now, + "ro", + domain.ProjectStateActive, + uint64(20211108), + "project-name-2", + false, + false, + false, + domain.PrivateLabelingSettingAllowLoginUserResourceOwnerPolicy, + }, + }, + ), + }, + object: &Projects{ + SearchResponse: SearchResponse{ + Count: 2, + }, + Projects: []*Project{ + { + ID: "id-1", + CreationDate: now, + ChangeDate: now, + ResourceOwner: "ro", + State: domain.ProjectStateActive, + Sequence: 20211108, + Name: "project-name-1", + ProjectRoleAssertion: true, + ProjectRoleCheck: true, + HasProjectCheck: true, + PrivateLabelingSetting: domain.PrivateLabelingSettingEnforceProjectResourceOwnerPolicy, + }, + { + ID: "id-2", + CreationDate: now, + ChangeDate: now, + ResourceOwner: "ro", + State: domain.ProjectStateActive, + Sequence: 20211108, + Name: "project-name-2", + ProjectRoleAssertion: false, + ProjectRoleCheck: false, + HasProjectCheck: false, + PrivateLabelingSetting: domain.PrivateLabelingSettingAllowLoginUserResourceOwnerPolicy, + }, + }, + }, + }, + { + name: "prepareProjectsQuery sql err", + prepare: prepareProjectsQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(`SELECT zitadel.projections.projects.id,`+ + ` zitadel.projections.projects.creation_date,`+ + ` zitadel.projections.projects.change_date,`+ + ` zitadel.projections.projects.resource_owner,`+ + ` zitadel.projections.projects.state,`+ + ` zitadel.projections.projects.sequence,`+ + ` zitadel.projections.projects.name,`+ + ` zitadel.projections.projects.project_role_assertion,`+ + ` zitadel.projections.projects.project_role_check,`+ + ` zitadel.projections.projects.has_project_check,`+ + ` zitadel.projections.projects.private_labeling_setting,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.projects`), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: nil, + }, + { + name: "prepareProjectQuery no result", + prepare: prepareProjectQuery, + want: want{ + sqlExpectations: mockQueries( + `SELECT zitadel.projections.projects.id,`+ + ` zitadel.projections.projects.creation_date,`+ + ` zitadel.projections.projects.change_date,`+ + ` zitadel.projections.projects.resource_owner,`+ + ` zitadel.projections.projects.state,`+ + ` zitadel.projections.projects.sequence,`+ + ` zitadel.projections.projects.name,`+ + ` zitadel.projections.projects.project_role_assertion,`+ + ` zitadel.projections.projects.project_role_check,`+ + ` zitadel.projections.projects.has_project_check,`+ + ` zitadel.projections.projects.private_labeling_setting`+ + ` FROM zitadel.projections.projects`, + nil, + nil, + ), + err: func(err error) (error, bool) { + if !errs.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: (*Project)(nil), + }, + { + name: "prepareProjectQuery found", + prepare: prepareProjectQuery, + want: want{ + sqlExpectations: mockQuery( + regexp.QuoteMeta(`SELECT zitadel.projections.projects.id,`+ + ` zitadel.projections.projects.creation_date,`+ + ` zitadel.projections.projects.change_date,`+ + ` zitadel.projections.projects.resource_owner,`+ + ` zitadel.projections.projects.state,`+ + ` zitadel.projections.projects.sequence,`+ + ` zitadel.projections.projects.name,`+ + ` zitadel.projections.projects.project_role_assertion,`+ + ` zitadel.projections.projects.project_role_check,`+ + ` zitadel.projections.projects.has_project_check,`+ + ` zitadel.projections.projects.private_labeling_setting`+ + ` FROM zitadel.projections.projects`), + []string{ + "id", + "creation_date", + "change_date", + "resource_owner", + "state", + "sequence", + "name", + "project_role_assertion", + "project_role_check", + "has_project_check", + "private_labeling_setting", + }, + []driver.Value{ + "id", + now, + now, + "ro", + domain.ProjectStateActive, + uint64(20211108), + "project-name", + true, + true, + true, + domain.PrivateLabelingSettingEnforceProjectResourceOwnerPolicy, + }, + ), + }, + object: &Project{ + ID: "id", + CreationDate: now, + ChangeDate: now, + ResourceOwner: "ro", + State: domain.ProjectStateActive, + Sequence: 20211108, + Name: "project-name", + ProjectRoleAssertion: true, + ProjectRoleCheck: true, + HasProjectCheck: true, + PrivateLabelingSetting: domain.PrivateLabelingSettingEnforceProjectResourceOwnerPolicy, + }, + }, + { + name: "prepareProjectQuery sql err", + prepare: prepareProjectQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(`SELECT zitadel.projections.projects.id,`+ + ` zitadel.projections.projects.creation_date,`+ + ` zitadel.projections.projects.change_date,`+ + ` zitadel.projections.projects.resource_owner,`+ + ` zitadel.projections.projects.state,`+ + ` zitadel.projections.projects.sequence,`+ + ` zitadel.projections.projects.name,`+ + ` zitadel.projections.projects.project_role_assertion,`+ + ` zitadel.projections.projects.project_role_check,`+ + ` zitadel.projections.projects.has_project_check,`+ + ` zitadel.projections.projects.private_labeling_setting`+ + ` FROM zitadel.projections.projects`), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) + }) + } +}