mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:37:32 +00:00
feat: store assets in database (#3290)
* feat: use database as asset storage * being only uploading assets if allowed * tests * fixes * cleanup after merge * renaming * various fixes * fix: change to repository event types and removed unused code * feat: set default features * error handling * error handling and naming * fix tests * fix tests * fix merge * rename
This commit is contained in:
182
internal/static/database/crdb.go
Normal file
182
internal/static/database/crdb.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
errs "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
caos_errors "github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/static"
|
||||
)
|
||||
|
||||
var _ static.Storage = (*crdbStorage)(nil)
|
||||
|
||||
const (
|
||||
assetsTable = "system.assets"
|
||||
AssetColInstanceID = "instance_id"
|
||||
AssetColType = "asset_type"
|
||||
AssetColLocation = "location"
|
||||
AssetColResourceOwner = "resource_owner"
|
||||
AssetColName = "name"
|
||||
AssetColData = "data"
|
||||
AssetColContentType = "content_type"
|
||||
AssetColHash = "hash"
|
||||
AssetColUpdatedAt = "updated_at"
|
||||
)
|
||||
|
||||
type crdbStorage struct {
|
||||
client *sql.DB
|
||||
}
|
||||
|
||||
func NewStorage(client *sql.DB, _ map[string]interface{}) (static.Storage, error) {
|
||||
return &crdbStorage{client: client}, nil
|
||||
}
|
||||
|
||||
func (c *crdbStorage) PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) {
|
||||
data, err := io.ReadAll(object)
|
||||
if err != nil {
|
||||
return nil, caos_errors.ThrowInternal(err, "DATAB-Dfwvq", "Errors.Internal")
|
||||
}
|
||||
stmt, args, err := squirrel.Insert(assetsTable).
|
||||
Columns(AssetColInstanceID, AssetColResourceOwner, AssetColName, AssetColType, AssetColContentType, AssetColData, AssetColUpdatedAt).
|
||||
Values(instanceID, resourceOwner, name, objectType, contentType, data, "now()").
|
||||
Suffix(fmt.Sprintf(
|
||||
"ON CONFLICT (%s, %s, %s) DO UPDATE"+
|
||||
" SET %s = $5, %s = $6"+
|
||||
" RETURNING %s, %s", AssetColInstanceID, AssetColResourceOwner, AssetColName, AssetColContentType, AssetColData, AssetColHash, AssetColUpdatedAt)).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, caos_errors.ThrowInternal(err, "DATAB-32DG1", "Errors.Internal")
|
||||
}
|
||||
var hash string
|
||||
var updatedAt time.Time
|
||||
err = c.client.QueryRowContext(ctx, stmt, args...).Scan(&hash, &updatedAt)
|
||||
if err != nil {
|
||||
return nil, caos_errors.ThrowInternal(err, "DATAB-D2g2q", "Errors.Internal")
|
||||
}
|
||||
return &static.Asset{
|
||||
InstanceID: instanceID,
|
||||
Name: name,
|
||||
Hash: hash,
|
||||
Size: objectSize,
|
||||
LastModified: updatedAt,
|
||||
Location: location,
|
||||
ContentType: contentType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *crdbStorage) GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*static.Asset, error), error) {
|
||||
query, args, err := squirrel.Select(AssetColData, AssetColContentType, AssetColHash, AssetColUpdatedAt).
|
||||
From(assetsTable).
|
||||
Where(squirrel.Eq{
|
||||
AssetColInstanceID: instanceID,
|
||||
AssetColResourceOwner: resourceOwner,
|
||||
AssetColName: name,
|
||||
}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, nil, caos_errors.ThrowInternal(err, "DATAB-GE3hz", "Errors.Internal")
|
||||
}
|
||||
var data []byte
|
||||
asset := &static.Asset{
|
||||
InstanceID: instanceID,
|
||||
ResourceOwner: resourceOwner,
|
||||
Name: name,
|
||||
}
|
||||
err = c.client.QueryRowContext(ctx, query, args...).
|
||||
Scan(
|
||||
&data,
|
||||
&asset.ContentType,
|
||||
&asset.Hash,
|
||||
&asset.LastModified,
|
||||
)
|
||||
if err != nil {
|
||||
if errs.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil, caos_errors.ThrowNotFound(err, "DATAB-pCP8P", "Errors.Assets.Object.NotFound")
|
||||
}
|
||||
return nil, nil, caos_errors.ThrowInternal(err, "DATAB-Sfgb3", "Errors.Assets.Object.GetFailed")
|
||||
}
|
||||
asset.Size = int64(len(data))
|
||||
return data,
|
||||
func() (*static.Asset, error) {
|
||||
return asset, nil
|
||||
},
|
||||
nil
|
||||
}
|
||||
|
||||
func (c *crdbStorage) GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*static.Asset, error) {
|
||||
query, args, err := squirrel.Select(AssetColContentType, AssetColLocation, "length("+AssetColData+")", AssetColHash, AssetColUpdatedAt).
|
||||
From(assetsTable).
|
||||
Where(squirrel.Eq{
|
||||
AssetColInstanceID: instanceID,
|
||||
AssetColResourceOwner: resourceOwner,
|
||||
AssetColName: name,
|
||||
}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, caos_errors.ThrowInternal(err, "DATAB-rggt2", "Errors.Internal")
|
||||
}
|
||||
asset := &static.Asset{
|
||||
InstanceID: instanceID,
|
||||
ResourceOwner: resourceOwner,
|
||||
Name: name,
|
||||
}
|
||||
err = c.client.QueryRowContext(ctx, query, args...).
|
||||
Scan(
|
||||
&asset.ContentType,
|
||||
&asset.Location,
|
||||
&asset.Size,
|
||||
&asset.Hash,
|
||||
&asset.LastModified,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, caos_errors.ThrowInternal(err, "DATAB-Dbh2s", "Errors.Internal")
|
||||
}
|
||||
return asset, nil
|
||||
}
|
||||
|
||||
func (c *crdbStorage) RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error {
|
||||
stmt, args, err := squirrel.Delete(assetsTable).
|
||||
Where(squirrel.Eq{
|
||||
AssetColInstanceID: instanceID,
|
||||
AssetColResourceOwner: resourceOwner,
|
||||
AssetColName: name,
|
||||
}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "DATAB-Sgvwq", "Errors.Internal")
|
||||
}
|
||||
_, err = c.client.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "DATAB-RHNgf", "Errors.Assets.Object.RemoveFailed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *crdbStorage) RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType static.ObjectType) error {
|
||||
stmt, args, err := squirrel.Delete(assetsTable).
|
||||
Where(squirrel.Eq{
|
||||
AssetColInstanceID: instanceID,
|
||||
AssetColResourceOwner: resourceOwner,
|
||||
AssetColType: objectType,
|
||||
}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "DATAB-Sfgeq", "Errors.Internal")
|
||||
}
|
||||
_, err = c.client.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "DATAB-Efgt2", "Errors.Assets.Object.RemoveFailed")
|
||||
}
|
||||
return nil
|
||||
}
|
203
internal/static/database/crdb_test.go
Normal file
203
internal/static/database/crdb_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/caos/zitadel/internal/static"
|
||||
)
|
||||
|
||||
var (
|
||||
testNow = time.Now()
|
||||
)
|
||||
|
||||
const (
|
||||
objectStmt = "INSERT INTO system.assets" +
|
||||
" (instance_id,resource_owner,name,asset_type,content_type,data,updated_at)" +
|
||||
" VALUES ($1,$2,$3,$4,$5,$6,$7)" +
|
||||
" ON CONFLICT (instance_id, resource_owner, name) DO UPDATE SET" +
|
||||
" content_type = $5, data = $6" +
|
||||
" RETURNING hash"
|
||||
)
|
||||
|
||||
func Test_crdbStorage_CreateObject(t *testing.T) {
|
||||
type fields struct {
|
||||
client db
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
instanceID string
|
||||
location string
|
||||
resourceOwner string
|
||||
name string
|
||||
contentType string
|
||||
objectType static.ObjectType
|
||||
data io.Reader
|
||||
objectSize int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *static.Asset
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"create ok",
|
||||
fields{
|
||||
client: prepareDB(t,
|
||||
expectQuery(
|
||||
objectStmt,
|
||||
[]string{
|
||||
"hash",
|
||||
"updated_at",
|
||||
},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"md5Hash",
|
||||
testNow,
|
||||
},
|
||||
},
|
||||
"instanceID",
|
||||
"resourceOwner",
|
||||
"name",
|
||||
static.ObjectTypeUserAvatar,
|
||||
"contentType",
|
||||
[]byte("test"),
|
||||
"now()",
|
||||
)),
|
||||
},
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
instanceID: "instanceID",
|
||||
location: "location",
|
||||
resourceOwner: "resourceOwner",
|
||||
name: "name",
|
||||
contentType: "contentType",
|
||||
data: bytes.NewReader([]byte("test")),
|
||||
objectSize: 4,
|
||||
},
|
||||
&static.Asset{
|
||||
InstanceID: "instanceID",
|
||||
Name: "name",
|
||||
Hash: "md5Hash",
|
||||
Size: 4,
|
||||
LastModified: testNow,
|
||||
Location: "location",
|
||||
ContentType: "contentType",
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &crdbStorage{
|
||||
client: tt.fields.client.db,
|
||||
}
|
||||
got, err := c.PutObject(tt.args.ctx, tt.args.instanceID, tt.args.location, tt.args.resourceOwner, tt.args.name, tt.args.contentType, tt.args.objectType, tt.args.data, tt.args.objectSize)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CreateObject() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateObject() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type db struct {
|
||||
mock sqlmock.Sqlmock
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func prepareDB(t *testing.T, expectations ...expectation) db {
|
||||
t.Helper()
|
||||
client, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create sql mock: %v", err)
|
||||
}
|
||||
for _, expectation := range expectations {
|
||||
expectation(mock)
|
||||
}
|
||||
return db{
|
||||
mock: mock,
|
||||
db: client,
|
||||
}
|
||||
}
|
||||
|
||||
type expectation func(m sqlmock.Sqlmock)
|
||||
|
||||
func expectExists(query string, value bool, args ...driver.Value) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(value))
|
||||
}
|
||||
}
|
||||
|
||||
func expectQueryErr(query string, err error, args ...driver.Value) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
|
||||
}
|
||||
}
|
||||
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
result := sqlmock.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
if cols[len(cols)-1] == "count" {
|
||||
row = append(row, count)
|
||||
}
|
||||
result.AddRow(row...)
|
||||
}
|
||||
q.WillReturnRows(result)
|
||||
q.RowsWillBeClosed()
|
||||
}
|
||||
}
|
||||
|
||||
func expectExec(stmt string, err error, args ...driver.Value) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
return
|
||||
}
|
||||
query.WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectBegin(err error) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectBegin()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func expectCommit(err error) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectCommit()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func expectRollback(err error) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectRollback()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user