chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

View File

@@ -0,0 +1,32 @@
package config
import (
"database/sql"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/static"
"github.com/zitadel/zitadel/internal/static/database"
"github.com/zitadel/zitadel/internal/static/s3"
"github.com/zitadel/zitadel/internal/zerrors"
)
type AssetStorageConfig struct {
Type string
Cache middleware.CacheConfig
Config map[string]interface{} `mapstructure:",remain"`
}
func (a *AssetStorageConfig) NewStorage(client *sql.DB) (static.Storage, error) {
t, ok := storage[a.Type]
if !ok {
return nil, zerrors.ThrowInternalf(nil, "STATIC-dsbjh", "config type %s not supported", a.Type)
}
return t(client, a.Config)
}
var storage = map[string]static.CreateStorage{
"db": database.NewStorage,
"": database.NewStorage,
"s3": s3.NewStorage,
}

View File

@@ -0,0 +1,199 @@
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"time"
"github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/static"
"github.com/zitadel/zitadel/internal/zerrors"
)
var _ static.Storage = (*storage)(nil)
const (
assetsTable = "system.assets"
AssetColInstanceID = "instance_id"
AssetColType = "asset_type"
AssetColLocation = "location"
AssetColResourceOwner = "resource_owner"
AssetColName = "name"
AssetColData = "data"
AssetColContentType = "content_type"
AssetColHash = "hash"
AssetColUpdatedAt = "updated_at"
)
type storage struct {
client *sql.DB
}
func NewStorage(client *sql.DB, _ map[string]interface{}) (static.Storage, error) {
return &storage{client: client}, nil
}
func (c *storage) PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) {
data, err := io.ReadAll(object)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-Dfwvq", "Errors.Internal")
}
stmt, args, err := squirrel.Insert(assetsTable).
Columns(AssetColInstanceID, AssetColResourceOwner, AssetColName, AssetColType, AssetColContentType, AssetColData, AssetColUpdatedAt).
Values(instanceID, resourceOwner, name, objectType.String(), contentType, data, "now()").
Suffix(fmt.Sprintf(
"ON CONFLICT (%s, %s, %s) DO UPDATE"+
" SET %s = $5, %s = $6"+
" RETURNING %s, %s", AssetColInstanceID, AssetColResourceOwner, AssetColName, AssetColContentType, AssetColData, AssetColHash, AssetColUpdatedAt)).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-32DG1", "Errors.Internal")
}
var hash string
var updatedAt time.Time
err = c.client.QueryRowContext(ctx, stmt, args...).Scan(&hash, &updatedAt)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-D2g2q", "Errors.Internal")
}
return &static.Asset{
InstanceID: instanceID,
Name: name,
Hash: hash,
Size: objectSize,
LastModified: updatedAt,
Location: location,
ContentType: contentType,
}, nil
}
func (c *storage) GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*static.Asset, error), error) {
query, args, err := squirrel.Select(AssetColData, AssetColContentType, AssetColHash, AssetColUpdatedAt).
From(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColName: name,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, nil, zerrors.ThrowInternal(err, "DATAB-GE3hz", "Errors.Internal")
}
var data []byte
asset := &static.Asset{
InstanceID: instanceID,
ResourceOwner: resourceOwner,
Name: name,
}
err = c.client.QueryRowContext(ctx, query, args...).
Scan(
&data,
&asset.ContentType,
&asset.Hash,
&asset.LastModified,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, zerrors.ThrowNotFound(err, "DATAB-pCP8P", "Errors.Assets.Object.NotFound")
}
return nil, nil, zerrors.ThrowInternal(err, "DATAB-Sfgb3", "Errors.Assets.Object.GetFailed")
}
asset.Size = int64(len(data))
return data,
func() (*static.Asset, error) {
return asset, nil
},
nil
}
func (c *storage) GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*static.Asset, error) {
query, args, err := squirrel.Select(AssetColContentType, AssetColLocation, "length("+AssetColData+")", AssetColHash, AssetColUpdatedAt).
From(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColName: name,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-rggt2", "Errors.Internal")
}
asset := &static.Asset{
InstanceID: instanceID,
ResourceOwner: resourceOwner,
Name: name,
}
err = c.client.QueryRowContext(ctx, query, args...).
Scan(
&asset.ContentType,
&asset.Location,
&asset.Size,
&asset.Hash,
&asset.LastModified,
)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-Dbh2s", "Errors.Internal")
}
return asset, nil
}
func (c *storage) RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error {
stmt, args, err := squirrel.Delete(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColName: name,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Sgvwq", "Errors.Internal")
}
_, err = c.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-RHNgf", "Errors.Assets.Object.RemoveFailed")
}
return nil
}
func (c *storage) RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType static.ObjectType) error {
stmt, args, err := squirrel.Delete(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
AssetColResourceOwner: resourceOwner,
AssetColType: objectType.String(),
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Sfgeq", "Errors.Internal")
}
_, err = c.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Efgt2", "Errors.Assets.Object.RemoveFailed")
}
return nil
}
func (c *storage) RemoveInstanceObjects(ctx context.Context, instanceID string) error {
stmt, args, err := squirrel.Delete(assetsTable).
Where(squirrel.Eq{
AssetColInstanceID: instanceID,
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Sfgeq", "Errors.Internal")
}
_, err = c.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DATAB-Efgt2", "Errors.Assets.Object.RemoveFailed")
}
return nil
}

View File

@@ -0,0 +1,360 @@
package database
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"io"
"reflect"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/static"
)
var (
testNow = time.Now()
)
const (
createObjectStmt = "INSERT INTO system.assets" +
" (instance_id,resource_owner,name,asset_type,content_type,data,updated_at)" +
" VALUES ($1,$2,$3,$4,$5,$6,$7)" +
" ON CONFLICT (instance_id, resource_owner, name) DO UPDATE SET" +
" content_type = $5, data = $6" +
" RETURNING hash"
removeObjectStmt = "DELETE FROM system.assets" +
" WHERE instance_id = $1" +
" AND name = $2" +
" AND resource_owner = $3"
removeObjectsStmt = "DELETE FROM system.assets" +
" WHERE asset_type = $1" +
" AND instance_id = $2" +
" AND resource_owner = $3"
removeInstanceObjectsStmt = "DELETE FROM system.assets" +
" WHERE instance_id = $1"
)
func Test_dbStorage_CreateObject(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
location string
resourceOwner string
name string
contentType string
objectType static.ObjectType
data io.Reader
objectSize int64
}
tests := []struct {
name string
fields fields
args args
want *static.Asset
wantErr bool
}{
{
"create ok",
fields{
client: prepareDB(t,
expectQuery(
createObjectStmt,
[]string{
"hash",
"updated_at",
},
[][]driver.Value{
{
"md5Hash",
testNow,
},
},
"instanceID",
"resourceOwner",
"name",
static.ObjectTypeUserAvatar.String(),
"contentType",
[]byte("test"),
"now()",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
location: "location",
resourceOwner: "resourceOwner",
name: "name",
contentType: "contentType",
objectType: static.ObjectTypeUserAvatar,
data: bytes.NewReader([]byte("test")),
objectSize: 4,
},
&static.Asset{
InstanceID: "instanceID",
Name: "name",
Hash: "md5Hash",
Size: 4,
LastModified: testNow,
Location: "location",
ContentType: "contentType",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
got, err := c.PutObject(tt.args.ctx, tt.args.instanceID, tt.args.location, tt.args.resourceOwner, tt.args.name, tt.args.contentType, tt.args.objectType, tt.args.data, tt.args.objectSize)
if (err != nil) != tt.wantErr {
t.Errorf("CreateObject() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateObject() got = %v, want %v", got, tt.want)
}
})
}
}
func Test_dbStorage_RemoveObject(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
resourceOwner string
name string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeObjectStmt,
nil,
"instanceID",
"name",
"resourceOwner",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
resourceOwner: "resourceOwner",
name: "name",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
err := c.RemoveObject(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveObject() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_dbStorage_RemoveObjects(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
resourceOwner string
objectType static.ObjectType
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeObjectsStmt,
nil, static.ObjectTypeUserAvatar.String(),
"instanceID",
"resourceOwner",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
resourceOwner: "resourceOwner",
objectType: static.ObjectTypeUserAvatar,
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
err := c.RemoveObjects(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.objectType)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveObjects() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_dbStorage_RemoveInstanceObjects(t *testing.T) {
type fields struct {
client db
}
type args struct {
ctx context.Context
instanceID string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
"remove ok",
fields{
client: prepareDB(t,
expectExec(
removeInstanceObjectsStmt,
nil,
"instanceID",
)),
},
args{
ctx: context.Background(),
instanceID: "instanceID",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &storage{
client: tt.fields.client.db,
}
err := c.RemoveInstanceObjects(tt.args.ctx, tt.args.instanceID)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveInstanceObjects() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
type db struct {
mock sqlmock.Sqlmock
db *sql.DB
}
func prepareDB(t *testing.T, expectations ...expectation) db {
t.Helper()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatalf("unable to create sql mock: %v", err)
}
for _, expectation := range expectations {
expectation(mock)
}
return db{
mock: mock,
db: client,
}
}
type expectation func(m sqlmock.Sqlmock)
func expectExists(query string, value bool, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnRows(m.NewRows([]string{"exists"}).AddRow(value))
}
}
func expectQueryErr(query string, err error, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
}
}
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
row = append(row, count)
}
result.AddRow(row...)
}
q.WillReturnRows(result)
q.RowsWillBeClosed()
}
}
func expectExec(stmt string, err error, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
if err != nil {
query.WillReturnError(err)
return
}
query.WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectBegin(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectBegin()
if err != nil {
query.WillReturnError(err)
}
}
}
func expectCommit(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectCommit()
if err != nil {
query.WillReturnError(err)
}
}
}
func expectRollback(err error) expectation {
return func(m sqlmock.Sqlmock) {
query := m.ExpectRollback()
if err != nil {
query.WillReturnError(err)
}
}
}

View File

@@ -0,0 +1,3 @@
package static
//go:generate mockgen -source storage.go -destination ./mock/storage_mock.go -package mock

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,130 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: storage.go
//
// Generated by this command:
//
// mockgen -source storage.go -destination ./mock/storage_mock.go -package mock
//
// Package mock is a generated GoMock package.
package mock
import (
context "context"
io "io"
reflect "reflect"
static "github.com/zitadel/zitadel/internal/static"
gomock "go.uber.org/mock/gomock"
)
// MockStorage is a mock of Storage interface.
type MockStorage struct {
ctrl *gomock.Controller
recorder *MockStorageMockRecorder
}
// MockStorageMockRecorder is the mock recorder for MockStorage.
type MockStorageMockRecorder struct {
mock *MockStorage
}
// NewMockStorage creates a new mock instance.
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
mock := &MockStorage{ctrl: ctrl}
mock.recorder = &MockStorageMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder
}
// GetObject mocks base method.
func (m *MockStorage) GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*static.Asset, error), error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetObject", ctx, instanceID, resourceOwner, name)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(func() (*static.Asset, error))
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetObject indicates an expected call of GetObject.
func (mr *MockStorageMockRecorder) GetObject(ctx, instanceID, resourceOwner, name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObject", reflect.TypeOf((*MockStorage)(nil).GetObject), ctx, instanceID, resourceOwner, name)
}
// GetObjectInfo mocks base method.
func (m *MockStorage) GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*static.Asset, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetObjectInfo", ctx, instanceID, resourceOwner, name)
ret0, _ := ret[0].(*static.Asset)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetObjectInfo indicates an expected call of GetObjectInfo.
func (mr *MockStorageMockRecorder) GetObjectInfo(ctx, instanceID, resourceOwner, name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjectInfo", reflect.TypeOf((*MockStorage)(nil).GetObjectInfo), ctx, instanceID, resourceOwner, name)
}
// PutObject mocks base method.
func (m *MockStorage) PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PutObject", ctx, instanceID, location, resourceOwner, name, contentType, objectType, object, objectSize)
ret0, _ := ret[0].(*static.Asset)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PutObject indicates an expected call of PutObject.
func (mr *MockStorageMockRecorder) PutObject(ctx, instanceID, location, resourceOwner, name, contentType, objectType, object, objectSize any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutObject", reflect.TypeOf((*MockStorage)(nil).PutObject), ctx, instanceID, location, resourceOwner, name, contentType, objectType, object, objectSize)
}
// RemoveInstanceObjects mocks base method.
func (m *MockStorage) RemoveInstanceObjects(ctx context.Context, instanceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveInstanceObjects", ctx, instanceID)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveInstanceObjects indicates an expected call of RemoveInstanceObjects.
func (mr *MockStorageMockRecorder) RemoveInstanceObjects(ctx, instanceID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveInstanceObjects", reflect.TypeOf((*MockStorage)(nil).RemoveInstanceObjects), ctx, instanceID)
}
// RemoveObject mocks base method.
func (m *MockStorage) RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveObject", ctx, instanceID, resourceOwner, name)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveObject indicates an expected call of RemoveObject.
func (mr *MockStorageMockRecorder) RemoveObject(ctx, instanceID, resourceOwner, name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveObject", reflect.TypeOf((*MockStorage)(nil).RemoveObject), ctx, instanceID, resourceOwner, name)
}
// RemoveObjects mocks base method.
func (m *MockStorage) RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType static.ObjectType) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveObjects", ctx, instanceID, resourceOwner, objectType)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveObjects indicates an expected call of RemoveObjects.
func (mr *MockStorageMockRecorder) RemoveObjects(ctx, instanceID, resourceOwner, objectType any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveObjects", reflect.TypeOf((*MockStorage)(nil).RemoveObjects), ctx, instanceID, resourceOwner, objectType)
}

View File

@@ -0,0 +1,63 @@
package mock
import (
"context"
"io"
"testing"
"time"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/static"
"github.com/zitadel/zitadel/internal/zerrors"
)
func NewStorage(t *testing.T) *MockStorage {
return NewMockStorage(gomock.NewController(t))
}
func (m *MockStorage) ExpectPutObject() *MockStorage {
m.EXPECT().
PutObject(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) {
hash, _ := io.ReadAll(object)
return &static.Asset{
InstanceID: instanceID,
Name: name,
Hash: string(hash),
Size: objectSize,
LastModified: time.Now(),
Location: location,
ContentType: contentType,
}, nil
})
return m
}
func (m *MockStorage) ExpectPutObjectError() *MockStorage {
m.EXPECT().
PutObject(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, zerrors.ThrowInternal(nil, "", ""))
return m
}
func (m *MockStorage) ExpectRemoveObjectNoError() *MockStorage {
m.EXPECT().
RemoveObject(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
return m
}
func (m *MockStorage) ExpectRemoveObjectsNoError() *MockStorage {
m.EXPECT().
RemoveObjects(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
return m
}
func (m *MockStorage) ExpectRemoveObjectError() *MockStorage {
m.EXPECT().
RemoveObject(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(zerrors.ThrowInternal(nil, "", ""))
return m
}

View File

@@ -0,0 +1,51 @@
package s3
import (
"database/sql"
"encoding/json"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/zitadel/zitadel/internal/static"
"github.com/zitadel/zitadel/internal/zerrors"
)
type Config struct {
Endpoint string
AccessKeyID string
SecretAccessKey string
SSL bool
Location string
BucketPrefix string
MultiDelete bool
}
func (c *Config) NewStorage() (static.Storage, error) {
minioClient, err := minio.New(c.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(c.AccessKeyID, c.SecretAccessKey, ""),
Secure: c.SSL,
Region: c.Location,
})
if err != nil {
return nil, zerrors.ThrowInternal(err, "MINIO-2n9fs", "Errors.Assets.Store.NotInitialized")
}
return &Minio{
Client: minioClient,
Location: c.Location,
BucketPrefix: c.BucketPrefix,
MultiDelete: c.MultiDelete,
}, nil
}
func NewStorage(_ *sql.DB, rawConfig map[string]interface{}) (static.Storage, error) {
configData, err := json.Marshal(rawConfig)
if err != nil {
return nil, zerrors.ThrowInternal(err, "MINIO-Ef2f2", "could not map config")
}
c := new(Config)
if err := json.Unmarshal(configData, c); err != nil {
return nil, zerrors.ThrowInternal(err, "MINIO-GB4nw", "could not map config")
}
return c.NewStorage()
}

View File

@@ -0,0 +1,187 @@
package s3
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"github.com/minio/minio-go/v7"
"github.com/zitadel/logging"
"golang.org/x/sync/errgroup"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/static"
"github.com/zitadel/zitadel/internal/zerrors"
)
var _ static.Storage = (*Minio)(nil)
type Minio struct {
Client *minio.Client
Location string
BucketPrefix string
MultiDelete bool
}
func (m *Minio) PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) {
err := m.createBucket(ctx, instanceID, location)
if err != nil && !zerrors.IsErrorAlreadyExists(err) {
return nil, err
}
bucketName := m.prefixBucketName(instanceID)
objectName := fmt.Sprintf("%s/%s", resourceOwner, name)
info, err := m.Client.PutObject(ctx, bucketName, objectName, object, objectSize, minio.PutObjectOptions{ContentType: contentType})
if err != nil {
return nil, zerrors.ThrowInternal(err, "MINIO-590sw", "Errors.Assets.Object.PutFailed")
}
return &static.Asset{
InstanceID: info.Bucket,
ResourceOwner: resourceOwner,
Name: info.Key,
Hash: info.ETag,
Size: info.Size,
LastModified: info.LastModified,
Location: info.Location,
ContentType: contentType,
}, nil
}
func (m *Minio) GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*static.Asset, error), error) {
bucketName := m.prefixBucketName(instanceID)
objectName := fmt.Sprintf("%s/%s", resourceOwner, name)
object, err := m.Client.GetObject(ctx, bucketName, objectName, minio.GetObjectOptions{})
if err != nil {
return nil, nil, zerrors.ThrowInternal(err, "MINIO-VGDgv", "Errors.Assets.Object.GetFailed")
}
info := func() (*static.Asset, error) {
info, err := object.Stat()
if err != nil {
return nil, zerrors.ThrowInternal(err, "MINIO-F96xF", "Errors.Assets.Object.GetFailed")
}
return m.objectToAssetInfo(instanceID, resourceOwner, info), nil
}
asset, err := io.ReadAll(object)
if err != nil {
return nil, nil, zerrors.ThrowInternal(err, "MINIO-SFef1", "Errors.Assets.Object.GetFailed")
}
return asset, info, nil
}
func (m *Minio) GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*static.Asset, error) {
bucketName := m.prefixBucketName(instanceID)
objectName := fmt.Sprintf("%s/%s", resourceOwner, name)
objectInfo, err := m.Client.StatObject(ctx, bucketName, objectName, minio.StatObjectOptions{})
if err != nil {
if errResp := minio.ToErrorResponse(err); errResp.StatusCode == http.StatusNotFound {
return nil, zerrors.ThrowNotFound(err, "MINIO-Gdfh4", "Errors.Assets.Object.GetFailed")
}
return nil, zerrors.ThrowInternal(err, "MINIO-1vySX", "Errors.Assets.Object.GetFailed")
}
return m.objectToAssetInfo(instanceID, resourceOwner, objectInfo), nil
}
func (m *Minio) RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error {
bucketName := m.prefixBucketName(instanceID)
objectName := fmt.Sprintf("%s/%s", resourceOwner, name)
err := m.Client.RemoveObject(ctx, bucketName, objectName, minio.RemoveObjectOptions{})
if err != nil {
return zerrors.ThrowInternal(err, "MINIO-x85RT", "Errors.Assets.Object.RemoveFailed")
}
return nil
}
func (m *Minio) RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType static.ObjectType) error {
bucketName := m.prefixBucketName(instanceID)
objectsCh := make(chan minio.ObjectInfo)
g := new(errgroup.Group)
var path string
switch objectType {
case static.ObjectTypeStyling:
path = domain.LabelPolicyPrefix + "/"
default:
return nil
}
g.Go(func() error {
defer close(objectsCh)
objects, cancel := m.listObjects(ctx, bucketName, resourceOwner, true)
for object := range objects {
if err := object.Err; err != nil {
cancel()
if errResp := minio.ToErrorResponse(err); errResp.StatusCode == http.StatusNotFound {
logging.WithFields("bucketName", bucketName, "path", path).Warn("list objects for remove failed with not found")
continue
}
return zerrors.ThrowInternal(object.Err, "MINIO-WQF32", "Errors.Assets.Object.ListFailed")
}
objectsCh <- object
}
return nil
})
if m.MultiDelete {
for objError := range m.Client.RemoveObjects(ctx, bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) {
return zerrors.ThrowInternal(objError.Err, "MINIO-Sfdgr", "Errors.Assets.Object.RemoveFailed")
}
return g.Wait()
}
for objectInfo := range objectsCh {
if err := m.Client.RemoveObject(ctx, bucketName, objectInfo.Key, minio.RemoveObjectOptions{GovernanceBypass: true}); err != nil {
return zerrors.ThrowInternal(err, "MINIO-GVgew", "Errors.Assets.Object.RemoveFailed")
}
}
return g.Wait()
}
func (m *Minio) RemoveInstanceObjects(ctx context.Context, instanceID string) error {
bucketName := m.prefixBucketName(instanceID)
return m.Client.RemoveBucket(ctx, bucketName)
}
func (m *Minio) createBucket(ctx context.Context, name, location string) error {
if location == "" {
location = m.Location
}
name = m.prefixBucketName(name)
exists, err := m.Client.BucketExists(ctx, name)
if err != nil {
logging.WithFields("bucketname", name).WithError(err).Error("cannot check if bucket exists")
return zerrors.ThrowInternal(err, "MINIO-1b8fs", "Errors.Assets.Bucket.Internal")
}
if exists {
return zerrors.ThrowAlreadyExists(nil, "MINIO-9n3MK", "Errors.Assets.Bucket.AlreadyExists")
}
err = m.Client.MakeBucket(ctx, name, minio.MakeBucketOptions{Region: location})
if err != nil {
return zerrors.ThrowInternal(err, "MINIO-4m90d", "Errors.Assets.Bucket.CreateFailed")
}
return nil
}
func (m *Minio) listObjects(ctx context.Context, bucketName, prefix string, recursive bool) (<-chan minio.ObjectInfo, context.CancelFunc) {
ctxCancel, cancel := context.WithCancel(ctx)
return m.Client.ListObjects(ctxCancel, bucketName, minio.ListObjectsOptions{
Prefix: prefix,
Recursive: recursive,
}), cancel
}
func (m *Minio) objectToAssetInfo(bucketName string, resourceOwner string, object minio.ObjectInfo) *static.Asset {
return &static.Asset{
InstanceID: bucketName,
ResourceOwner: resourceOwner,
Name: object.Key,
Hash: object.ETag,
Size: object.Size,
LastModified: object.LastModified,
ContentType: object.ContentType,
}
}
func (m *Minio) prefixBucketName(name string) string {
return strings.ToLower(m.BucketPrefix + "-" + name)
}

View File

@@ -0,0 +1,53 @@
package static
import (
"context"
"database/sql"
"io"
"time"
)
type CreateStorage func(client *sql.DB, rawConfig map[string]interface{}) (Storage, error)
type Storage interface {
PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType ObjectType, object io.Reader, objectSize int64) (*Asset, error)
GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*Asset, error), error)
GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*Asset, error)
RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error
RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType ObjectType) error
RemoveInstanceObjects(ctx context.Context, instanceID string) error
//TODO: add functionality to move asset location
}
type ObjectType int32
const (
ObjectTypeUserAvatar ObjectType = iota
ObjectTypeStyling
)
func (o ObjectType) String() string {
switch o {
case ObjectTypeUserAvatar:
return "0"
case ObjectTypeStyling:
return "1"
default:
return ""
}
}
type Asset struct {
InstanceID string
ResourceOwner string
Name string
Hash string
Size int64
LastModified time.Time
Location string
ContentType string
}
func (a *Asset) VersionedName() string {
return a.Name + "?v=" + a.Hash
}