chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

View File

@@ -0,0 +1,199 @@
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"time"
"github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/static"
"github.com/zitadel/zitadel/internal/zerrors"
)
var _ static.Storage = (*storage)(nil)
const (
assetsTable = "system.assets"
AssetColInstanceID = "instance_id"
AssetColType = "asset_type"
AssetColLocation = "location"
AssetColResourceOwner = "resource_owner"
AssetColName = "name"
AssetColData = "data"
AssetColContentType = "content_type"
AssetColHash = "hash"
AssetColUpdatedAt = "updated_at"
)
type storage struct {
client *sql.DB
}
func NewStorage(client *sql.DB, _ map[string]interface{}) (static.Storage, error) {
return &storage{client: client}, nil
}
func (c *storage) PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) {
data, err := io.ReadAll(object)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-Dfwvq", "Errors.Internal")
}
stmt, args, err := squirrel.Insert(assetsTable).
Columns(AssetColInstanceID, AssetColResourceOwner, AssetColName, AssetColType, AssetColContentType, AssetColData, AssetColUpdatedAt).
Values(instanceID, resourceOwner, name, objectType.String(), contentType, data, "now()").
Suffix(fmt.Sprintf(
"ON CONFLICT (%s, %s, %s) DO UPDATE"+
" SET %s = $5, %s = $6"+
" RETURNING %s, %s", AssetColInstanceID, AssetColResourceOwner, AssetColName, AssetColContentType, AssetColData, AssetColHash, AssetColUpdatedAt)).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-32DG1", "Errors.Internal")
}
var hash string
var updatedAt time.Time
err = c.client.QueryRowContext(ctx, stmt, args...).Scan(&hash, &updatedAt)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-D2g2q", "Errors.Internal")
}
return &static.Asset{
InstanceID: instanceID,
Name: name,
Hash: hash,
Size: objectSize,
LastModified: updatedAt,
Location: location,
ContentType: contentType,
}, nil
}
func (c *storage) GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*static.Asset, error), error) {
query, args, err := squirrel.Select(AssetColData, AssetColContentType, AssetColHash, AssetColUpdatedAt).
From(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColName: name,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, nil, zerrors.ThrowInternal(err, "DATAB-GE3hz", "Errors.Internal")
}
var data []byte
asset := &static.Asset{
InstanceID: instanceID,
ResourceOwner: resourceOwner,
Name: name,
}
err = c.client.QueryRowContext(ctx, query, args...).
Scan(
&data,
&asset.ContentType,
&asset.Hash,
&asset.LastModified,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, zerrors.ThrowNotFound(err, "DATAB-pCP8P", "Errors.Assets.Object.NotFound")
}
return nil, nil, zerrors.ThrowInternal(err, "DATAB-Sfgb3", "Errors.Assets.Object.GetFailed")
}
asset.Size = int64(len(data))
return data,
func() (*static.Asset, error) {
return asset, nil
},
nil
}
func (c *storage) GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*static.Asset, error) {
query, args, err := squirrel.Select(AssetColContentType, AssetColLocation, "length("+AssetColData+")", AssetColHash, AssetColUpdatedAt).
From(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColName: name,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-rggt2", "Errors.Internal")
}
asset := &static.Asset{
InstanceID: instanceID,
ResourceOwner: resourceOwner,
Name: name,
}
err = c.client.QueryRowContext(ctx, query, args...).
Scan(
&asset.ContentType,
&asset.Location,
&asset.Size,
&asset.Hash,
&asset.LastModified,
)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-Dbh2s", "Errors.Internal")
}
return asset, nil
}
func (c *storage) RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error {
stmt, args, err := squirrel.Delete(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColName: name,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Sgvwq", "Errors.Internal")
}
_, err = c.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-RHNgf", "Errors.Assets.Object.RemoveFailed")
}
return nil
}
func (c *storage) RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType static.ObjectType) error {
stmt, args, err := squirrel.Delete(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColType: objectType.String(),
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Sfgeq", "Errors.Internal")
}
_, err = c.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Efgt2", "Errors.Assets.Object.RemoveFailed")
}
return nil
}
func (c *storage) RemoveInstanceObjects(ctx context.Context, instanceID string) error {
stmt, args, err := squirrel.Delete(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Sfgeq", "Errors.Internal")
}
_, err = c.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Efgt2", "Errors.Assets.Object.RemoveFailed")
}
return nil
}

View File

@@ -0,0 +1,360 @@
package database
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"io"
"reflect"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/static"
)
var (
testNow = time.Now()
)
const (
createObjectStmt = "INSERT INTO system.assets" +
" (instance_id,resource_owner,name,asset_type,content_type,data,updated_at)" +
" VALUES ($1,$2,$3,$4,$5,$6,$7)" +
" ON CONFLICT (instance_id, resource_owner, name) DO UPDATE SET" +
" content_type = $5, data = $6" +
" RETURNING hash"
removeObjectStmt = "DELETE FROM system.assets" +
" WHERE instance_id = $1" +
" AND name = $2" +
" AND resource_owner = $3"
removeObjectsStmt = "DELETE FROM system.assets" +
" WHERE asset_type = $1" +
" AND instance_id = $2" +
" AND resource_owner = $3"
removeInstanceObjectsStmt = "DELETE FROM system.assets" +
" WHERE instance_id = $1"
)
func Test_dbStorage_CreateObject(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
location string
resourceOwner string
name string
contentType string
objectType static.ObjectType
data io.Reader
objectSize int64
}
tests := []struct {
name string
fields fields
args args
want *static.Asset
wantErr bool
}{
{
"create ok",
fields{
client: prepareDB(t,
expectQuery(
createObjectStmt,
[]string{
"hash",
"updated_at",
},
[][]driver.Value{
{
"md5Hash",
testNow,
},
},
"instanceID",
"resourceOwner",
"name",
static.ObjectTypeUserAvatar.String(),
"contentType",
[]byte("test"),
"now()",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
location: "location",
resourceOwner: "resourceOwner",
name: "name",
contentType: "contentType",
objectType: static.ObjectTypeUserAvatar,
data: bytes.NewReader([]byte("test")),
objectSize: 4,
},
&static.Asset{
InstanceID: "instanceID",
Name: "name",
Hash: "md5Hash",
Size: 4,
LastModified: testNow,
Location: "location",
ContentType: "contentType",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
got, err := c.PutObject(tt.args.ctx, tt.args.instanceID, tt.args.location, tt.args.resourceOwner, tt.args.name, tt.args.contentType, tt.args.objectType, tt.args.data, tt.args.objectSize)
if (err != nil) != tt.wantErr {
t.Errorf("CreateObject() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateObject() got = %v, want %v", got, tt.want)
}
})
}
}
func Test_dbStorage_RemoveObject(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
resourceOwner string
name string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeObjectStmt,
nil,
"instanceID",
"name",
"resourceOwner",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
resourceOwner: "resourceOwner",
name: "name",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
err := c.RemoveObject(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveObject() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_dbStorage_RemoveObjects(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
resourceOwner string
objectType static.ObjectType
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeObjectsStmt,
nil, static.ObjectTypeUserAvatar.String(),
"instanceID",
"resourceOwner",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
resourceOwner: "resourceOwner",
objectType: static.ObjectTypeUserAvatar,
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
err := c.RemoveObjects(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.objectType)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveObjects() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_dbStorage_RemoveInstanceObjects(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeInstanceObjectsStmt,
nil,
"instanceID",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
err := c.RemoveInstanceObjects(tt.args.ctx, tt.args.instanceID)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveInstanceObjects() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
type db struct {
mock sqlmock.Sqlmock
db *sql.DB
}
func prepareDB(t *testing.T, expectations ...expectation) db {
t.Helper()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatalf("unable to create sql mock: %v", err)
}
for _, expectation := range expectations {
expectation(mock)
}
return db{
mock: mock,
db: client,
}
}
type expectation func(m sqlmock.Sqlmock)
func expectExists(query string, value bool, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnRows(m.NewRows([]string{"exists"}).AddRow(value))
}
}
func expectQueryErr(query string, err error, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
}
}
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
row = append(row, count)
}
result.AddRow(row...)
}
q.WillReturnRows(result)
q.RowsWillBeClosed()
}
}
func expectExec(stmt string, err error, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
if err != nil {
query.WillReturnError(err)
return
}
query.WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectBegin(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectBegin()
if err != nil {
query.WillReturnError(err)
}
}
}
func expectCommit(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectCommit()
if err != nil {
query.WillReturnError(err)
}
}
}
func expectRollback(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectRollback()
if err != nil {
query.WillReturnError(err)
}
}
}