314 lines
6.8 KiB
Go
Raw Normal View History

package database
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"io"
"reflect"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/static"
)
var (
testNow = time.Now()
)
const (
createObjectStmt = "INSERT INTO system.assets" +
" (instance_id,resource_owner,name,asset_type,content_type,data,updated_at)" +
" VALUES ($1,$2,$3,$4,$5,$6,$7)" +
" ON CONFLICT (instance_id, resource_owner, name) DO UPDATE SET" +
" content_type = $5, data = $6" +
" RETURNING hash"
removeObjectStmt = "DELETE FROM system.assets" +
" WHERE instance_id = $1" +
" AND name = $2" +
" AND resource_owner = $3"
removeObjectsStmt = "DELETE FROM system.assets" +
" WHERE asset_type = $1" +
" AND instance_id = $2" +
" AND resource_owner = $3"
)
func Test_crdbStorage_CreateObject(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
location string
resourceOwner string
name string
contentType string
objectType static.ObjectType
data io.Reader
objectSize int64
}
tests := []struct {
name string
fields fields
args args
want *static.Asset
wantErr bool
}{
{
"create ok",
fields{
client: prepareDB(t,
expectQuery(
createObjectStmt,
[]string{
"hash",
"updated_at",
},
[][]driver.Value{
{
"md5Hash",
testNow,
},
},
"instanceID",
"resourceOwner",
"name",
static.ObjectTypeUserAvatar.String(),
"contentType",
[]byte("test"),
"now()",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
location: "location",
resourceOwner: "resourceOwner",
name: "name",
contentType: "contentType",
objectType: static.ObjectTypeUserAvatar,
data: bytes.NewReader([]byte("test")),
objectSize: 4,
},
&static.Asset{
InstanceID: "instanceID",
Name: "name",
Hash: "md5Hash",
Size: 4,
LastModified: testNow,
Location: "location",
ContentType: "contentType",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &crdbStorage{
client: tt.fields.client.db,
}
got, err := c.PutObject(tt.args.ctx, tt.args.instanceID, tt.args.location, tt.args.resourceOwner, tt.args.name, tt.args.contentType, tt.args.objectType, tt.args.data, tt.args.objectSize)
if (err != nil) != tt.wantErr {
t.Errorf("CreateObject() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateObject() got = %v, want %v", got, tt.want)
}
})
}
}
func Test_crdbStorage_RemoveObject(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
resourceOwner string
name string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeObjectStmt,
nil,
"instanceID",
"name",
"resourceOwner",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
resourceOwner: "resourceOwner",
name: "name",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &crdbStorage{
client: tt.fields.client.db,
}
err := c.RemoveObject(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveObject() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_crdbStorage_RemoveObjects(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
resourceOwner string
objectType static.ObjectType
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeObjectsStmt,
nil, static.ObjectTypeUserAvatar.String(),
"instanceID",
"resourceOwner",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
resourceOwner: "resourceOwner",
objectType: static.ObjectTypeUserAvatar,
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &crdbStorage{
client: tt.fields.client.db,
}
err := c.RemoveObjects(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.objectType)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveObjects() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
type db struct {
mock sqlmock.Sqlmock
db *sql.DB
}
func prepareDB(t *testing.T, expectations ...expectation) db {
t.Helper()
client, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("unable to create sql mock: %v", err)
}
for _, expectation := range expectations {
expectation(mock)
}
return db{
mock: mock,
db: client,
}
}
type expectation func(m sqlmock.Sqlmock)
func expectExists(query string, value bool, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(value))
}
}
func expectQueryErr(query string, err error, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
}
}
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
result := sqlmock.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
row = append(row, count)
}
result.AddRow(row...)
}
q.WillReturnRows(result)
q.RowsWillBeClosed()
}
}
func expectExec(stmt string, err error, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
if err != nil {
query.WillReturnError(err)
return
}
query.WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectBegin(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectBegin()
if err != nil {
query.WillReturnError(err)
}
}
}
func expectCommit(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectCommit()
if err != nil {
query.WillReturnError(err)
}
}
}
func expectRollback(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectRollback()
if err != nil {
query.WillReturnError(err)
}
}
}