package query import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "regexp" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/internal/database" db_mock "github.com/zitadel/zitadel/internal/database/mock" "github.com/zitadel/zitadel/internal/domain" ) const ( expectedDeviceAuthQueryC = `SELECT` + ` projections.device_auth_requests2.client_id,` + ` projections.device_auth_requests2.device_code,` + ` projections.device_auth_requests2.user_code,` + ` projections.device_auth_requests2.scopes,` + ` projections.device_auth_requests2.audience` + ` FROM projections.device_auth_requests2` expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC + ` WHERE projections.device_auth_requests2.instance_id = $1` + ` AND projections.device_auth_requests2.user_code = $2` ) var ( expectedDeviceAuthQuery = regexp.QuoteMeta(expectedDeviceAuthQueryC) expectedDeviceAuthWhereUserCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereUserCodeQueryC) expectedDeviceAuthValues = []driver.Value{ "client-id", "device1", "user-code", database.TextArray[string]{"a", "b", "c"}, []string{"projectID", "clientID"}, } expectedDeviceAuth = &domain.AuthRequestDevice{ ClientID: "client-id", DeviceCode: "device1", UserCode: "user-code", Scopes: []string{"a", "b", "c"}, Audience: []string{"projectID", "clientID"}, } ) func TestQueries_DeviceAuthRequestByUserCode(t *testing.T) { client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter))) if err != nil { t.Fatalf("failed to build mock client: %v", err) } defer client.Close() mock.ExpectBegin() mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows( mock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), ) mock.ExpectCommit() q := Queries{ client: &database.DB{DB: client}, } got, err := q.DeviceAuthRequestByUserCode(context.TODO(), "789") require.NoError(t, err) assert.Equal(t, expectedDeviceAuth, got) require.NoError(t, mock.ExpectationsWereMet()) } func Test_prepareDeviceAuthQuery(t *testing.T) { type want struct { sqlExpectations sqlExpectation err checkErr } tests := []struct { name string want want object *domain.AuthRequestDevice }{ { name: "success", want: want{ sqlExpectations: mockQueries( expectedDeviceAuthQuery, deviceAuthSelectColumns, [][]driver.Value{expectedDeviceAuthValues}, ), }, object: expectedDeviceAuth, }, { name: "not found error", want: want{ sqlExpectations: mockQueryErr( expectedDeviceAuthQuery, sql.ErrNoRows, ), err: func(err error) (error, bool) { if !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("err should be sql.ErrNoRows got: %w", err), false } return nil, true }, }, object: nil, }, { name: "other error", want: want{ sqlExpectations: mockQueryErr( expectedDeviceAuthQuery, sql.ErrConnDone, ), err: func(err error) (error, bool) { if !errors.Is(err, sql.ErrConnDone) { return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false } return nil, true }, }, object: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assertPrepare(t, prepareDeviceAuthQuery, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) }) } }