package query import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "regexp" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/v1/models" ) const ( expectedDeviceAuthQueryC = `SELECT` + ` projections.device_authorizations.id,` + ` projections.device_authorizations.client_id,` + ` projections.device_authorizations.scopes,` + ` projections.device_authorizations.expires,` + ` projections.device_authorizations.state,` + ` projections.device_authorizations.subject` + ` FROM projections.device_authorizations` expectedDeviceAuthWhereDeviceCodeQueryC = expectedDeviceAuthQueryC + ` WHERE projections.device_authorizations.client_id = $1` + ` AND projections.device_authorizations.device_code = $2` + ` AND projections.device_authorizations.instance_id = $3` expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC + ` WHERE projections.device_authorizations.instance_id = $1` + ` AND projections.device_authorizations.user_code = $2` ) var ( expectedDeviceAuthQuery = regexp.QuoteMeta(expectedDeviceAuthQueryC) expectedDeviceAuthWhereDeviceCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereDeviceCodeQueryC) expectedDeviceAuthWhereUserCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereUserCodeQueryC) expectedDeviceAuthValues = []driver.Value{ "primary-id", "client-id", database.StringArray{"a", "b", "c"}, testNow, domain.DeviceAuthStateApproved, "subject", } expectedDeviceAuth = &domain.DeviceAuth{ ObjectRoot: models.ObjectRoot{ AggregateID: "primary-id", }, ClientID: "client-id", Scopes: []string{"a", "b", "c"}, Expires: testNow, State: domain.DeviceAuthStateApproved, Subject: "subject", } ) func TestQueries_DeviceAuthByDeviceCode(t *testing.T) { client, mock, err := sqlmock.New() if err != nil { t.Fatalf("failed to build mock client: %v", err) } defer client.Close() mock.ExpectQuery(expectedDeviceAuthWhereDeviceCodeQuery).WillReturnRows( sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), ) q := Queries{ client: &database.DB{DB: client}, } got, err := q.DeviceAuthByDeviceCode(context.TODO(), "123", "456") require.NoError(t, err) assert.Equal(t, expectedDeviceAuth, got) require.NoError(t, mock.ExpectationsWereMet()) } func TestQueries_DeviceAuthByUserCode(t *testing.T) { client, mock, err := sqlmock.New() if err != nil { t.Fatalf("failed to build mock client: %v", err) } defer client.Close() mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows( sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), ) q := Queries{ client: &database.DB{DB: client}, } got, err := q.DeviceAuthByUserCode(context.TODO(), "789") require.NoError(t, err) assert.Equal(t, expectedDeviceAuth, got) require.NoError(t, mock.ExpectationsWereMet()) } func Test_prepareDeviceAuthQuery(t *testing.T) { type want struct { sqlExpectations sqlExpectation err checkErr } tests := []struct { name string want want object any }{ { name: "success", want: want{ sqlExpectations: mockQueries( expectedDeviceAuthQuery, deviceAuthSelectColumns, [][]driver.Value{expectedDeviceAuthValues}, ), }, object: expectedDeviceAuth, }, { name: "not found error", want: want{ sqlExpectations: mockQueryErr( expectedDeviceAuthQuery, sql.ErrNoRows, ), err: func(err error) (error, bool) { if !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("err should be sql.ErrNoRows got: %w", err), false } return nil, true }, }, }, { name: "other error", want: want{ sqlExpectations: mockQueryErr( expectedDeviceAuthQuery, sql.ErrConnDone, ), err: func(err error) (error, bool) { if !errors.Is(err, sql.ErrConnDone) { return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false } return nil, true }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertPrepare(t, prepareDeviceAuthQuery, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) }) } }