zitadel/internal/query/device_auth_test.go

134 lines
3.5 KiB
Go
Raw Normal View History

package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/domain"
)
const (
	// expectedDeviceAuthQueryC is the plain SELECT statement over the
	// projections.device_auth_requests2 table, listing the columns in the
	// order client_id, device_code, user_code, scopes, audience.
	expectedDeviceAuthQueryC = "SELECT " +
		"projections.device_auth_requests2.client_id, " +
		"projections.device_auth_requests2.device_code, " +
		"projections.device_auth_requests2.user_code, " +
		"projections.device_auth_requests2.scopes, " +
		"projections.device_auth_requests2.audience " +
		"FROM projections.device_auth_requests2"

	// expectedDeviceAuthWhereUserCodeQueryC extends the plain SELECT with a
	// filter on instance_id ($1) and user_code ($2).
	expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC +
		" WHERE projections.device_auth_requests2.instance_id = $1 " +
		"AND projections.device_auth_requests2.user_code = $2"
)
// Escaped query patterns: regexp.QuoteMeta turns the raw SQL text into a
// pattern that matches it literally (presumably because the SQL mock compares
// expected queries as regular expressions — confirm against the mock setup).
var (
	expectedDeviceAuthQuery              = regexp.QuoteMeta(expectedDeviceAuthQueryC)
	expectedDeviceAuthWhereUserCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereUserCodeQueryC)
)

// Shared fixtures: one mocked result row and the domain object the tests
// compare against for that row.
var (
	// expectedDeviceAuth is the object the tests expect for the row described
	// by expectedDeviceAuthValues.
	expectedDeviceAuth = &domain.AuthRequestDevice{
		ClientID:   "client-id",
		DeviceCode: "device1",
		UserCode:   "user-code",
		Scopes:     []string{"a", "b", "c"},
		Audience:   []string{"projectID", "clientID"},
	}

	// expectedDeviceAuthValues holds the values of a single result row, in the
	// column order of expectedDeviceAuthQueryC.
	expectedDeviceAuthValues = []driver.Value{
		"client-id", // client_id
		"device1",   // device_code
		"user-code", // user_code
		database.TextArray[string]{"a", "b", "c"}, // scopes
		[]string{"projectID", "clientID"},         // audience
	}
)
// TestQueries_DeviceAuthRequestByUserCode calls DeviceAuthRequestByUserCode on
// a Queries value backed by a sqlmock database. The mock expects a begin, the
// user-code SELECT and a commit, in that order, and returns the shared fixture
// row; the result must equal expectedDeviceAuth and every expectation must
// have been consumed.
func TestQueries_DeviceAuthRequestByUserCode(t *testing.T) {
	converter := sqlmock.ValueConverterOption(new(db_mock.TypeConverter))
	db, sqlMock, err := sqlmock.New(converter)
	if err != nil {
		t.Fatalf("failed to build mock client: %v", err)
	}
	defer db.Close()

	queries := Queries{
		client: &database.DB{DB: db},
	}

	// The single row handed back for the user-code query.
	rows := sqlMock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...)

	// Expectations are registered in the order the mock must observe them.
	sqlMock.ExpectBegin()
	sqlMock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows(rows)
	sqlMock.ExpectCommit()

	result, err := queries.DeviceAuthRequestByUserCode(context.TODO(), "789")
	require.NoError(t, err)
	assert.Equal(t, expectedDeviceAuth, result)
	require.NoError(t, sqlMock.ExpectationsWereMet())
}
// Test_prepareDeviceAuthQuery is a table-driven test that hands
// prepareDeviceAuthQuery to assertPrepare for three scenarios: a successful
// row, a query failing with sql.ErrNoRows and a query failing with
// sql.ErrConnDone.
func Test_prepareDeviceAuthQuery(t *testing.T) {
	// wantNoRows accepts only errors that wrap sql.ErrNoRows.
	wantNoRows := func(err error) (error, bool) {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, true
		}
		return fmt.Errorf("err should be sql.ErrNoRows got: %w", err), false
	}
	// wantConnDone accepts only errors that wrap sql.ErrConnDone.
	wantConnDone := func(err error) (error, bool) {
		if errors.Is(err, sql.ErrConnDone) {
			return nil, true
		}
		return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
	}

	cases := []struct {
		name     string
		expect   sqlExpectation            // mocked database behaviour
		errCheck checkErr                  // nil when no error check is supplied
		object   *domain.AuthRequestDevice // object passed to assertPrepare
	}{
		{
			name: "success",
			expect: mockQueries(
				expectedDeviceAuthQuery,
				deviceAuthSelectColumns,
				[][]driver.Value{expectedDeviceAuthValues},
			),
			object: expectedDeviceAuth,
		},
		{
			name:     "not found error",
			expect:   mockQueryErr(expectedDeviceAuthQuery, sql.ErrNoRows),
			errCheck: wantNoRows,
			object:   nil,
		},
		{
			name:     "other error",
			expect:   mockQueryErr(expectedDeviceAuthQuery, sql.ErrConnDone),
			errCheck: wantConnDone,
			object:   nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assertPrepare(t, prepareDeviceAuthQuery, tc.object, tc.expect, tc.errCheck, defaultPrepareArgs...)
		})
	}
}