perf: query data AS OF SYSTEM TIME (#5231)

Queries the data in the storage layser at the timestamp when the call hit the API layer
This commit is contained in:
Silvan
2023-02-27 22:36:43 +01:00
committed by GitHub
parent 80003939ad
commit e38abdcdf3
170 changed files with 3101 additions and 3169 deletions

View File

@@ -2,11 +2,11 @@ package eventsourcing
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/internal/admin/repository/eventsourcing/eventstore"
"github.com/zitadel/zitadel/internal/admin/repository/eventsourcing/spooler"
admin_view "github.com/zitadel/zitadel/internal/admin/repository/eventsourcing/view"
"github.com/zitadel/zitadel/internal/database"
eventstore2 "github.com/zitadel/zitadel/internal/eventstore"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
es_spol "github.com/zitadel/zitadel/internal/eventstore/v1/spooler"
@@ -23,7 +23,7 @@ type EsRepository struct {
eventstore.AdministratorRepo
}
func Start(ctx context.Context, conf Config, static static.Storage, dbClient *sql.DB, esV2 *eventstore2.Eventstore) (*EsRepository, error) {
func Start(ctx context.Context, conf Config, static static.Storage, dbClient *database.DB, esV2 *eventstore2.Eventstore) (*EsRepository, error) {
es, err := v1.Start(dbClient)
if err != nil {
return nil, err

View File

@@ -2,10 +2,10 @@ package spooler
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/internal/admin/repository/eventsourcing/handler"
"github.com/zitadel/zitadel/internal/admin/repository/eventsourcing/view"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
"github.com/zitadel/zitadel/internal/eventstore/v1/spooler"
@@ -20,11 +20,11 @@ type SpoolerConfig struct {
Handlers handler.Configs
}
func StartSpooler(ctx context.Context, c SpoolerConfig, es v1.Eventstore, esV2 *eventstore.Eventstore, view *view.View, sql *sql.DB, static static.Storage) *spooler.Spooler {
func StartSpooler(ctx context.Context, c SpoolerConfig, es v1.Eventstore, esV2 *eventstore.Eventstore, view *view.View, sql *database.DB, static static.Storage) *spooler.Spooler {
spoolerConfig := spooler.Config{
Eventstore: es,
EventstoreV2: esV2,
Locker: &locker{dbClient: sql},
Locker: &locker{dbClient: sql.DB},
ConcurrentWorkers: c.ConcurrentWorkers,
ConcurrentInstances: c.ConcurrentInstances,
ViewHandlers: handler.Register(ctx, c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, static),

View File

@@ -1,16 +1,15 @@
package view
import (
"database/sql"
"github.com/jinzhu/gorm"
"github.com/zitadel/zitadel/internal/database"
)
type View struct {
Db *gorm.DB
}
func StartView(sqlClient *sql.DB) (*View, error) {
func StartView(sqlClient *database.DB) (*View, error) {
gorm, err := gorm.Open("postgres", sqlClient)
if err != nil {
return nil, err

View File

@@ -82,7 +82,7 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, code
http.Error(w, err.Error(), code)
}
func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authConfig authz.Config, idGenerator id.Generator, storage static.Storage, queries *query.Queries, instanceInterceptor, assetCacheInterceptor, accessInterceptor func(handler http.Handler) http.Handler) http.Handler {
func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authConfig authz.Config, idGenerator id.Generator, storage static.Storage, queries *query.Queries, callDurationInterceptor, instanceInterceptor, assetCacheInterceptor, accessInterceptor func(handler http.Handler) http.Handler) http.Handler {
h := &Handler{
commands: commands,
errorHandler: DefaultErrorHandler,
@@ -94,7 +94,7 @@ func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authC
verifier.RegisterServer("Assets-API", "assets", AssetsService_AuthMethods)
router := mux.NewRouter()
router.Use(instanceInterceptor, assetCacheInterceptor, accessInterceptor)
router.Use(callDurationInterceptor, instanceInterceptor, assetCacheInterceptor, accessInterceptor)
RegisterRoutes(router, h)
router.PathPrefix("/{owner}").Methods("GET").HandlerFunc(DownloadHandleFunc(h, h.GetFile()))
return http_util.CopyHeadersToContext(http_mw.CORSInterceptor(router))

View File

@@ -0,0 +1,17 @@
package authz
import (
"context"
"time"
)
func Detach(ctx context.Context) context.Context { return detachedContext{ctx} }
type detachedContext struct {
parent context.Context
}
func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (v detachedContext) Done() <-chan struct{} { return nil }
func (v detachedContext) Err() error { return nil }
func (v detachedContext) Value(key interface{}) interface{} { return v.parent.Value(key) }

View File

@@ -0,0 +1,38 @@
package call
import (
"context"
"time"
)
type durationKey struct{}
var key *durationKey = (*durationKey)(nil)
// WithTimestamp sets [time.Now()] adds the call field to the context
// if it's not already set
func WithTimestamp(parent context.Context) context.Context {
if parent.Value(key) != nil {
return parent
}
return context.WithValue(parent, key, time.Now())
}
// FromContext returns the [time.Time] the call hit the api
func FromContext(ctx context.Context) (t time.Time) {
value := ctx.Value(key)
if t, ok := value.(time.Time); ok {
return t
}
return t
}
// Took returns the time the call took so far
func Took(ctx context.Context) time.Duration {
start := FromContext(ctx)
if start.IsZero() {
return 0
}
return time.Since(start)
}

View File

@@ -0,0 +1,119 @@
package call
import (
"context"
"testing"
"time"
)
func TestTook(t *testing.T) {
type args struct {
ctx context.Context
}
tests := []struct {
name string
args args
startIsZero bool
}{
{
name: "no start",
args: args{
ctx: context.Background(),
},
startIsZero: true,
},
{
name: "with start",
args: args{
ctx: WithTimestamp(context.Background()),
},
startIsZero: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Took(tt.args.ctx)
if tt.startIsZero && got != 0 {
t.Errorf("Duration should be 0 but was %v", got)
}
if !tt.startIsZero && got <= 0 {
t.Errorf("Duration should be greater 0 but was %d", got)
}
})
}
}
func TestFromContext(t *testing.T) {
type args struct {
ctx context.Context
}
tests := []struct {
name string
args args
isZero bool
}{
{
name: "no start",
args: args{
ctx: context.Background(),
},
isZero: true,
},
{
name: "with start",
args: args{
ctx: WithTimestamp(context.Background()),
},
isZero: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromContext(tt.args.ctx)
if tt.isZero != got.IsZero() {
t.Errorf("Time is zero should be %v but was %v", tt.isZero, got.IsZero())
}
})
}
}
func TestWithTimestamp(t *testing.T) {
start := time.Date(2019, 4, 29, 0, 0, 0, 0, time.UTC)
type args struct {
ctx context.Context
}
tests := []struct {
name string
args args
noPrevious bool
}{
{
name: "fresh context",
args: args{
ctx: context.WithValue(context.Background(), key, start),
},
noPrevious: true,
},
{
name: "with start",
args: args{
ctx: WithTimestamp(context.Background()),
},
noPrevious: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := WithTimestamp(tt.args.ctx)
val := got.Value(key).(time.Time)
if !tt.noPrevious && val.Before(start) {
t.Errorf("time should be now not %v", val)
}
if tt.noPrevious && val.After(start) {
t.Errorf("time should be start not %v", val)
}
})
}
}

View File

@@ -88,17 +88,6 @@ func (c *count) getProgress() string {
"project_grant_members " + strconv.Itoa(c.projectGrantMemberCount) + "/" + strconv.Itoa(c.projectGrantMemberLen)
}
func Detach(ctx context.Context) context.Context { return detachedContext{ctx} }
type detachedContext struct {
parent context.Context
}
func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (v detachedContext) Done() <-chan struct{} { return nil }
func (v detachedContext) Err() error { return nil }
func (v detachedContext) Value(key interface{}) interface{} { return v.parent.Value(key) }
func (s *Server) ImportData(ctx context.Context, req *admin_pb.ImportDataRequest) (_ *admin_pb.ImportDataResponse, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@@ -169,7 +158,7 @@ func (s *Server) ImportData(ctx context.Context, req *admin_pb.ImportDataRequest
if err != nil {
return nil, err
}
dctx := Detach(ctx)
dctx := authz.Detach(ctx)
go func() {
ch := make(chan importResponse, 1)
ctxTimeout, cancel := context.WithTimeout(dctx, timeoutDuration)

View File

@@ -71,6 +71,7 @@ func CreateGateway(ctx context.Context, g Gateway, port uint16, http1HostName st
}
func addInterceptors(handler http.Handler, http1HostName string) http.Handler {
handler = http_mw.CallDurationHandler(handler)
handler = http1Host(handler, http1HostName)
handler = http_mw.CORSInterceptor(handler)
handler = http_mw.DefaultTelemetryHandler(handler)

View File

@@ -0,0 +1,16 @@
package middleware
import (
"context"
"google.golang.org/grpc"
"github.com/zitadel/zitadel/internal/api/call"
)
func CallDurationHandler() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
ctx = call.WithTimestamp(ctx)
return handler(ctx, req)
}
}

View File

@@ -35,6 +35,7 @@ func CreateServer(
serverOptions := []grpc.ServerOption{
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
middleware.CallDurationHandler(),
middleware.DefaultTracingServer(),
middleware.MetricsHandler(metricTypes, grpc_api.Probes...),
middleware.NoCacheInterceptor(),

View File

@@ -0,0 +1,13 @@
package middleware
import (
"net/http"
"github.com/zitadel/zitadel/internal/api/call"
)
func CallDurationHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(call.WithTimestamp(r.Context())))
})
}

View File

@@ -187,6 +187,7 @@ func (o *OPStorage) getMaxKeySequence(ctx context.Context) (uint64, error) {
return o.eventstore.LatestSequence(ctx,
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
ResourceOwner(authz.GetInstance(ctx).InstanceID()).
AllowTimeTravel().
AddQuery().
AggregateTypes(keypair.AggregateType, instance.AggregateType).
Builder(),

View File

@@ -2,7 +2,6 @@ package oidc
import (
"context"
"database/sql"
"fmt"
"net/http"
"time"
@@ -18,6 +17,7 @@ import (
"github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
@@ -73,7 +73,7 @@ type OPStorage struct {
assetAPIPrefix func(ctx context.Context) string
}
func NewProvider(ctx context.Context, config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *sql.DB, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
func NewProvider(ctx context.Context, config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *database.DB, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
opConfig, err := createOPConfig(config, defaultLogoutRedirectURI, cryptoKey)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "OIDC-EGrqd", "cannot create op config: %w")
@@ -169,7 +169,7 @@ func customEndpoints(endpointConfig *EndpointConfig) []op.Option {
return options
}
func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, projections *sql.DB, externalSecure bool) *OPStorage {
func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
return &OPStorage{
repo: repo,
command: command,
@@ -182,7 +182,7 @@ func newStorage(config Config, command *command.Commands, query *query.Queries,
defaultRefreshTokenIdleExpiration: config.DefaultRefreshTokenIdleExpiration,
defaultRefreshTokenExpiration: config.DefaultRefreshTokenExpiration,
encAlg: encAlg,
locker: crdb.NewLocker(projections, locksTable, signingKey),
locker: crdb.NewLocker(db.DB, locksTable, signingKey),
assetAPIPrefix: assets.AssetAPI(externalSecure),
}
}

View File

@@ -2,7 +2,6 @@ package saml
import (
"context"
"database/sql"
"fmt"
"net/http"
@@ -14,6 +13,7 @@ import (
"github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/query"
@@ -38,7 +38,7 @@ func NewProvider(
encAlg crypto.EncryptionAlgorithm,
certEncAlg crypto.EncryptionAlgorithm,
es *eventstore.Eventstore,
projections *sql.DB,
projections *database.DB,
instanceHandler,
userAgentCookie,
accessHandler func(http.Handler) http.Handler,
@@ -89,12 +89,12 @@ func newStorage(
encAlg crypto.EncryptionAlgorithm,
certEncAlg crypto.EncryptionAlgorithm,
es *eventstore.Eventstore,
projections *sql.DB,
db *database.DB,
) (*Storage, error) {
return &Storage{
encAlg: encAlg,
certEncAlg: certEncAlg,
locker: crdb.NewLocker(projections, locksTable, signingKey),
locker: crdb.NewLocker(db.DB, locksTable, signingKey),
eventstore: es,
repo: repo,
command: command,

View File

@@ -88,7 +88,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) {
return f, nil
}
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
fSys, err := fs.Sub(static, "static")
if err != nil {
return nil, err
@@ -103,7 +103,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, inst
handler := mux.NewRouter()
handler.Use(instanceHandler, security, accessInterceptor)
handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor)
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure)
environmentJSON, err := createEnvironmentJSON(url, issuer(r), authz.GetInstance(r.Context()).ConsoleClientID(), customerPortal)

View File

@@ -67,8 +67,8 @@ func CreateLogin(config Config,
userAgentCookie,
issuerInterceptor,
oidcInstanceHandler,
samlInstanceHandler mux.MiddlewareFunc,
assetCache mux.MiddlewareFunc,
samlInstanceHandler,
assetCache,
accessHandler mux.MiddlewareFunc,
userCodeAlg crypto.EncryptionAlgorithm,
idpConfigAlg crypto.EncryptionAlgorithm,

View File

@@ -2,7 +2,6 @@ package eventsourcing
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/eventstore"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/spooler"
@@ -11,6 +10,7 @@ import (
"github.com/zitadel/zitadel/internal/command"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
eventstore2 "github.com/zitadel/zitadel/internal/eventstore"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
es_spol "github.com/zitadel/zitadel/internal/eventstore/v1/spooler"
@@ -34,7 +34,7 @@ type EsRepository struct {
eventstore.OrgRepository
}
func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, command *command.Commands, queries *query.Queries, dbClient *sql.DB, esV2 *eventstore2.Eventstore, oidcEncryption crypto.EncryptionAlgorithm, userEncryption crypto.EncryptionAlgorithm) (*EsRepository, error) {
func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, command *command.Commands, queries *query.Queries, dbClient *database.DB, esV2 *eventstore2.Eventstore, oidcEncryption crypto.EncryptionAlgorithm, userEncryption crypto.EncryptionAlgorithm) (*EsRepository, error) {
es, err := v1.Start(dbClient)
if err != nil {
return nil, err

View File

@@ -2,11 +2,11 @@ package spooler
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/handler"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
"github.com/zitadel/zitadel/internal/eventstore/v1/spooler"
@@ -21,11 +21,11 @@ type SpoolerConfig struct {
Handlers handler.Configs
}
func StartSpooler(ctx context.Context, c SpoolerConfig, es v1.Eventstore, esV2 *eventstore.Eventstore, view *view.View, client *sql.DB, systemDefaults sd.SystemDefaults, queries *query.Queries) *spooler.Spooler {
func StartSpooler(ctx context.Context, c SpoolerConfig, es v1.Eventstore, esV2 *eventstore.Eventstore, view *view.View, client *database.DB, systemDefaults sd.SystemDefaults, queries *query.Queries) *spooler.Spooler {
spoolerConfig := spooler.Config{
Eventstore: es,
EventstoreV2: esV2,
Locker: &locker{dbClient: client},
Locker: &locker{dbClient: client.DB},
ConcurrentWorkers: c.ConcurrentWorkers,
ConcurrentInstances: c.ConcurrentInstances,
ViewHandlers: handler.Register(ctx, c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, systemDefaults, queries),

View File

@@ -21,15 +21,6 @@ func (v *View) UserByID(userID, instanceID string) (*model.UserView, error) {
return view.UserByID(v.Db, userTable, userID, instanceID)
}
func (v *View) UserByUsername(userName, instanceID string) (*model.UserView, error) {
query, err := query.NewUserUsernameSearchQuery(userName, query.TextEquals)
if err != nil {
return nil, err
}
return v.userByID(instanceID, query)
}
func (v *View) UserByLoginName(loginName, instanceID string) (*model.UserView, error) {
loginNameQuery, err := query.NewUserLoginNamesSearchQuery(loginName)
if err != nil {

View File

@@ -1,11 +1,10 @@
package view
import (
"database/sql"
"github.com/jinzhu/gorm"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
eventstore "github.com/zitadel/zitadel/internal/eventstore/v1"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/query"
@@ -19,7 +18,7 @@ type View struct {
es eventstore.Eventstore
}
func StartView(sqlClient *sql.DB, keyAlgorithm crypto.EncryptionAlgorithm, queries *query.Queries, idGenerator id.Generator, es eventstore.Eventstore) (*View, error) {
func StartView(sqlClient *database.DB, keyAlgorithm crypto.EncryptionAlgorithm, queries *query.Queries, idGenerator id.Generator, es eventstore.Eventstore) (*View, error) {
gorm, err := gorm.Open("postgres", sqlClient)
if err != nil {
return nil, err

View File

@@ -9,15 +9,16 @@ import (
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
)
type AuthRequestCache struct {
client *sql.DB
client *database.DB
}
func Start(dbClient *sql.DB) *AuthRequestCache {
func Start(dbClient *database.DB) *AuthRequestCache {
return &AuthRequestCache{
client: dbClient,
}

View File

@@ -1,14 +1,13 @@
package authz
import (
"database/sql"
"github.com/zitadel/zitadel/internal/authz/repository"
"github.com/zitadel/zitadel/internal/authz/repository/eventsourcing"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/query"
)
func Start(queries *query.Queries, dbClient *sql.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure bool) (repository.Repository, error) {
func Start(queries *query.Queries, dbClient *database.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure bool) (repository.Repository, error) {
return eventsourcing.Start(queries, dbClient, keyEncryptionAlgorithm, externalSecure)
}

View File

@@ -2,12 +2,12 @@ package eventsourcing
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/internal/authz/repository"
"github.com/zitadel/zitadel/internal/authz/repository/eventsourcing/eventstore"
authz_view "github.com/zitadel/zitadel/internal/authz/repository/eventsourcing/view"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/query"
@@ -18,7 +18,7 @@ type EsRepository struct {
eventstore.TokenVerifierRepo
}
func Start(queries *query.Queries, dbClient *sql.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure bool) (repository.Repository, error) {
func Start(queries *query.Queries, dbClient *database.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure bool) (repository.Repository, error) {
es, err := v1.Start(dbClient)
if err != nil {
return nil, err

View File

@@ -1,8 +1,7 @@
package view
import (
"database/sql"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/query"
@@ -15,7 +14,7 @@ type View struct {
idGenerator id.Generator
}
func StartView(sqlClient *sql.DB, idGenerator id.Generator, queries *query.Queries) (*View, error) {
func StartView(sqlClient *database.DB, idGenerator id.Generator, queries *query.Queries) (*View, error) {
gorm, err := gorm.Open("postgres", sqlClient)
if err != nil {
return nil, err

View File

@@ -9,7 +9,6 @@ type quotaWriteModel struct {
eventstore.WriteModel
unit quota.Unit
active bool
config *quota.AddedEvent
}
// newQuotaWriteModel aggregateId is filled by reducing unit matching events
@@ -43,11 +42,9 @@ func (wm *quotaWriteModel) Reduce() error {
case *quota.AddedEvent:
wm.AggregateID = e.Aggregate().ID
wm.active = true
wm.config = e
case *quota.RemovedEvent:
wm.AggregateID = e.Aggregate().ID
wm.active = false
wm.config = nil
}
}
return wm.WriteModel.Reduce()

View File

@@ -1,45 +0,0 @@
package command
import (
"time"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/quota"
)
type quotaNotificationsWriteModel struct {
eventstore.WriteModel
periodStart time.Time
latestNotifiedThresholds map[string]uint16
}
func newQuotaNotificationsWriteModel(aggregateId, instanceId, resourceOwner string, periodStart time.Time) *quotaNotificationsWriteModel {
return &quotaNotificationsWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: aggregateId,
InstanceID: instanceId,
ResourceOwner: resourceOwner,
},
periodStart: periodStart,
latestNotifiedThresholds: make(map[string]uint16),
}
}
func (wm *quotaNotificationsWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(wm.ResourceOwner).
AddQuery().
InstanceID(wm.InstanceID).
AggregateTypes(quota.AggregateType).
AggregateIDs(wm.AggregateID).
CreationDateAfter(wm.periodStart).
EventTypes(quota.NotifiedEventType).Builder()
}
func (wm *quotaNotificationsWriteModel) Reduce() error {
for _, event := range wm.Events {
e := event.(*quota.NotifiedEvent)
wm.latestNotifiedThresholds[e.ID] = e.Threshold
}
return wm.WriteModel.Reduce()
}

View File

@@ -1,25 +0,0 @@
package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/repository/quota"
)
func (c *Commands) GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (*quota.AddedEvent, time.Time, error) {
wm, err := c.getQuotaWriteModel(ctx, instanceID, instanceID, unit)
if err != nil || !wm.active {
return nil, time.Time{}, err
}
return wm.config, pushPeriodStart(wm.config.From, wm.config.ResetInterval, time.Now()), nil
}
func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
next := from.Add(interval)
if next.After(now) {
return from
}
return pushPeriodStart(next, interval, now)
}

View File

@@ -2,6 +2,7 @@ package cockroach
import (
"database/sql"
"fmt"
"strconv"
"strings"
"time"
@@ -89,6 +90,15 @@ func (c *Config) Type() string {
return "cockroach"
}
func (c *Config) Timetravel(d time.Duration) string {
// verify that it is at least 1 micro second
if d < time.Microsecond {
d = time.Microsecond
}
return fmt.Sprintf(" AS OF SYSTEM TIME '-%d µs' ", d.Microseconds())
}
type User struct {
Username string
Password string

View File

@@ -0,0 +1,61 @@
package cockroach
import (
"testing"
"time"
)
func TestConfig_Timetravel(t *testing.T) {
type args struct {
d time.Duration
}
tests := []struct {
name string
args args
want string
}{
{
name: "no duration",
args: args{
d: 0,
},
want: " AS OF SYSTEM TIME '-1 µs' ",
},
{
name: "less than microsecond",
args: args{
d: 100 * time.Nanosecond,
},
want: " AS OF SYSTEM TIME '-1 µs' ",
},
{
name: "10 microseconds",
args: args{
d: 10 * time.Microsecond,
},
want: " AS OF SYSTEM TIME '-10 µs' ",
},
{
name: "10 milliseconds",
args: args{
d: 10 * time.Millisecond,
},
want: " AS OF SYSTEM TIME '-10000 µs' ",
},
{
name: "1 second",
args: args{
d: 1 * time.Second,
},
want: " AS OF SYSTEM TIME '-1000000 µs' ",
},
}
for _, tt := range tests {
c := &Config{}
t.Run(tt.name, func(t *testing.T) {
if got := c.Timetravel(tt.args.d); got != tt.want {
t.Errorf("Config.Timetravel() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -19,7 +19,12 @@ func (c *Config) SetConnector(connector dialect.Connector) {
c.connector = connector
}
func Connect(config Config, useAdmin bool) (*sql.DB, error) {
type DB struct {
*sql.DB
dialect.Database
}
func Connect(config Config, useAdmin bool) (*DB, error) {
client, err := config.connector.Connect(useAdmin)
if err != nil {
return nil, err
@@ -29,7 +34,10 @@ func Connect(config Config, useAdmin bool) (*sql.DB, error) {
return nil, errors.ThrowPreconditionFailed(err, "DATAB-0pIWD", "Errors.Database.Connection.Failed")
}
return client, nil
return &DB{
DB: client,
Database: config.connector,
}, nil
}
func DecodeHook(from, to reflect.Value) (interface{}, error) {
@@ -61,7 +69,7 @@ func DecodeHook(from, to reflect.Value) (interface{}, error) {
return Config{connector: connector}, nil
}
func (c Config) Database() string {
func (c Config) DatabaseName() string {
return c.connector.DatabaseName()
}

View File

@@ -3,6 +3,7 @@ package dialect
import (
"database/sql"
"sync"
"time"
)
type Config struct {
@@ -29,10 +30,15 @@ type Matcher interface {
type Connector interface {
Connect(useAdmin bool) (*sql.DB, error)
Password() string
Database
}
type Database interface {
DatabaseName() string
Username() string
Password() string
Type() string
Timetravel(time.Duration) string
}
func Register(matcher Matcher, config Connector, isDefault bool) {

View File

@@ -89,6 +89,10 @@ func (c *Config) Type() string {
return "postgres"
}
func (c *Config) Timetravel(time.Duration) string {
return ""
}
type User struct {
Username string
Password string

View File

@@ -1,16 +1,16 @@
package eventstore
import (
"database/sql"
"time"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore/repository"
z_sql "github.com/zitadel/zitadel/internal/eventstore/repository/sql"
)
type Config struct {
PushTimeout time.Duration
Client *sql.DB
Client *database.DB
repo repository.Repository
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
@@ -19,7 +20,7 @@ var (
type StatementHandlerConfig struct {
handler.ProjectionHandlerConfig
Client *sql.DB
Client *database.DB
SequenceTable string
LockTable string
FailedEventsTable string
@@ -34,7 +35,7 @@ type StatementHandler struct {
*handler.ProjectionHandler
Locker
client *sql.DB
client *database.DB
sequenceTable string
currentSequenceStmt string
updateSequencesBaseStmt string
@@ -74,7 +75,7 @@ func NewStatementHandler(
aggregates: aggregateTypes,
reduces: reduces,
bulkLimit: config.BulkLimit,
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionName),
Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
initCheck: config.InitCheck,
initialized: make(chan bool),
}
@@ -96,7 +97,7 @@ func (h *StatementHandler) SearchQuery(ctx context.Context, instanceIDs []string
return nil, 0, err
}
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit).AllowTimeTravel()
for _, aggregateType := range h.aggregates {
for _, instanceID := range instanceIDs {

View File

@@ -12,6 +12,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/repository"
@@ -114,6 +115,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AllowTimeTravel().
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
@@ -143,6 +145,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AllowTimeTravel().
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
@@ -171,7 +174,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
},
SequenceTable: tt.fields.sequenceTable,
BulkLimit: tt.fields.bulkLimit,
Client: client,
Client: &database.DB{
DB: client,
},
})
h.aggregates = tt.fields.aggregates
@@ -549,7 +554,9 @@ func TestStatementHandler_Update(t *testing.T) {
sequenceTable: "my_sequences",
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, "my_sequences"),
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, "my_sequences"),
client: client,
client: &database.DB{
DB: client,
},
}
h.aggregates = tt.fields.aggregates
@@ -1121,7 +1128,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
ProjectionName: tt.fields.projectionName,
RequeueEvery: 0,
},
Client: client,
Client: &database.DB{
DB: client,
},
FailedEventsTable: tt.fields.failedEventsTable,
MaxFailureCount: tt.fields.maxFailureCount,
},

View File

@@ -194,7 +194,7 @@ func (h *ProjectionHandler) schedule(ctx context.Context) {
var succeededOnce bool
var err error
// get every instance id except empty (system)
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).AddQuery().ExcludedInstanceID("")
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).AllowTimeTravel().AddQuery().ExcludedInstanceID("")
for range h.triggerProjection.C {
if !succeededOnce {
// (re)check if it has succeeded in the meantime

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"os"
"testing"
"time"
"github.com/cockroachdb/cockroach-go/v2/testserver"
"github.com/zitadel/logging"
@@ -14,7 +15,7 @@ import (
)
var (
testCRDBClient *sql.DB
testCRDBClient *database.DB
)
func TestMain(m *testing.M) {
@@ -23,7 +24,11 @@ func TestMain(m *testing.M) {
logging.WithFields("error", err).Fatal("unable to start db")
}
testCRDBClient, err = sql.Open("postgres", ts.PGURL().String())
testCRDBClient = &database.DB{
Database: new(testDB),
}
testCRDBClient.DB, err = sql.Open("postgres", ts.PGURL().String())
if err != nil {
logging.WithFields("error", err).Fatal("unable to connect to db")
}
@@ -39,7 +44,7 @@ func TestMain(m *testing.M) {
ts.Stop()
}()
if err = initDB(testCRDBClient); err != nil {
if err = initDB(testCRDBClient.DB); err != nil {
logging.WithFields("error", err).Fatal("migrations failed")
}
@@ -57,10 +62,20 @@ func initDB(db *sql.DB) error {
})
err := initialise.Init(db,
initialise.VerifyUser(config.Username(), ""),
initialise.VerifyDatabase(config.Database()),
initialise.VerifyGrant(config.Database(), config.Username()))
initialise.VerifyDatabase(config.DatabaseName()),
initialise.VerifyGrant(config.DatabaseName(), config.Username()))
if err != nil {
return err
}
return initialise.VerifyZitadel(db, *config)
}
type testDB struct{}
func (_ *testDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }
func (*testDB) DatabaseName() string { return "db" }
func (*testDB) Username() string { return "user" }
func (*testDB) Type() string { return "type" }

View File

@@ -2,7 +2,7 @@ package eventstore
import "time"
//ReadModel is the minimum representation of a read model.
// ReadModel is the minimum representation of a read model.
// It implements a basic reducer
// it might be saved in a database or in memory
type ReadModel struct {
@@ -15,14 +15,13 @@ type ReadModel struct {
InstanceID string `json:"-"`
}
//AppendEvents adds all the events to the read model.
// AppendEvents adds all the events to the read model.
// The function doesn't compute the new state of the read model
func (rm *ReadModel) AppendEvents(events ...Event) *ReadModel {
func (rm *ReadModel) AppendEvents(events ...Event) {
rm.Events = append(rm.Events, events...)
return rm
}
//Reduce is the basic implementation of reducer
// Reduce is the basic implementation of reducer
// If this function is extended the extending function should be the last step
func (rm *ReadModel) Reduce() error {
if len(rm.Events) == 0 {

View File

@@ -8,11 +8,12 @@ import (
// SearchQuery defines the which and how data are queried
type SearchQuery struct {
Columns Columns
Limit uint64
Desc bool
Filters [][]*Filter
Tx *sql.Tx
Columns Columns
Limit uint64
Desc bool
Filters [][]*Filter
Tx *sql.Tx
AllowTimeTravel bool
}
// Columns defines which fields of the event are needed for the query

View File

@@ -13,6 +13,7 @@ import (
"github.com/lib/pq"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
@@ -97,19 +98,19 @@ const (
)
type CRDB struct {
client *sql.DB
*database.DB
}
func NewCRDB(client *sql.DB) *CRDB {
func NewCRDB(client *database.DB) *CRDB {
return &CRDB{client}
}
func (db *CRDB) Health(ctx context.Context) error { return db.client.Ping() }
func (db *CRDB) Health(ctx context.Context) error { return db.Ping() }
// Push adds all events to the eventstreams of the aggregates.
// This call is transaction save. The transaction will be rolled back if one event fails
func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueConstraints ...*repository.UniqueConstraint) error {
err := crdb.ExecuteTx(ctx, db.client, nil, func(tx *sql.Tx) error {
err := crdb.ExecuteTx(ctx, db.DB.DB, nil, func(tx *sql.Tx) error {
var (
previousAggregateSequence Sequence
@@ -159,7 +160,7 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`)
func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error {
row := db.client.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID)
row := db.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID)
if row.Err() != nil {
return caos_errs.ThrowInvalidArgument(row.Err(), "SQL-7gtFA", "Errors.InvalidArgument")
}
@@ -168,7 +169,7 @@ func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error {
return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument")
}
if _, err := db.client.ExecContext(ctx, "CREATE SEQUENCE "+sequenceName); err != nil {
if _, err := db.ExecContext(ctx, "CREATE SEQUENCE "+sequenceName); err != nil {
return caos_errs.ThrowInternal(err, "SQL-7gtFA", "Errors.Internal")
}
@@ -249,7 +250,7 @@ func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQ
}
func (db *CRDB) db() *sql.DB {
return db.client
return db.DB.DB
}
func (db *CRDB) orderByEventSequence(desc bool) string {

View File

@@ -437,7 +437,10 @@ func TestCRDB_Push_OneAggregate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: testCRDBClient,
DB: &database.DB{
DB: testCRDBClient,
Database: new(testDB),
},
}
if tt.args.uniqueDataType != "" && tt.args.uniqueDataField != "" {
err := fillUniqueData(tt.args.uniqueDataType, tt.args.uniqueDataField, tt.args.uniqueDataInstanceID)
@@ -561,7 +564,10 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: testCRDBClient,
DB: &database.DB{
DB: testCRDBClient,
Database: new(testDB),
},
}
if err := db.Push(context.Background(), tt.args.events); (err != nil) != tt.res.wantErr {
t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr)
@@ -638,7 +644,7 @@ func TestCRDB_CreateInstance(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: testCRDBClient,
DB: &database.DB{DB: testCRDBClient},
}
if err := db.CreateInstance(context.Background(), tt.args.instanceID); (err != nil) != tt.res.wantErr {
@@ -776,7 +782,10 @@ func TestCRDB_Push_Parallel(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: testCRDBClient,
DB: &database.DB{
DB: testCRDBClient,
Database: new(testDB),
},
}
wg := sync.WaitGroup{}
@@ -897,7 +906,10 @@ func TestCRDB_Filter(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: testCRDBClient,
DB: &database.DB{
DB: testCRDBClient,
Database: new(testDB),
},
}
// setup initial data for query
@@ -987,7 +999,10 @@ func TestCRDB_LatestSequence(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: testCRDBClient,
DB: &database.DB{
DB: testCRDBClient,
Database: new(testDB),
},
}
// setup initial data for query
@@ -1131,7 +1146,10 @@ func TestCRDB_Push_ResourceOwner(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: testCRDBClient,
DB: &database.DB{
DB: testCRDBClient,
Database: new(testDB),
},
}
if err := db.Push(context.Background(), tt.args.events); err != nil {
t.Errorf("CRDB.Push() error = %v", err)

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"os"
"testing"
"time"
"github.com/cockroachdb/cockroach-go/v2/testserver"
"github.com/zitadel/logging"
@@ -53,8 +54,8 @@ func initDB(db *sql.DB) error {
err := initialise.Init(db,
initialise.VerifyUser(config.Username(), ""),
initialise.VerifyDatabase(config.Database()),
initialise.VerifyGrant(config.Database(), config.Username()))
initialise.VerifyDatabase(config.DatabaseName()),
initialise.VerifyGrant(config.DatabaseName(), config.Username()))
if err != nil {
return err
}
@@ -66,3 +67,13 @@ func fillUniqueData(unique_type, field, instanceID string) error {
_, err := testCRDBClient.Exec("INSERT INTO eventstore.unique_constraints (unique_type, unique_field, instance_id) VALUES ($1, $2, $3)", unique_type, field, instanceID)
return err
}
type testDB struct{}
func (_ *testDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }
func (*testDB) DatabaseName() string { return "db" }
func (*testDB) Username() string { return "user" }
func (*testDB) Type() string { return "type" }

View File

@@ -10,6 +10,8 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database/dialect"
z_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
@@ -24,6 +26,7 @@ type querier interface {
instanceIDsQuery() string
db() *sql.DB
orderByEventSequence(desc bool) string
dialect.Database
}
type scan func(dest ...interface{}) error
@@ -34,6 +37,11 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search
if where == "" || query == "" {
return z_errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
}
if searchQuery.Tx == nil {
if travel := prepareTimeTravel(ctx, criteria, searchQuery.AllowTimeTravel); travel != "" {
query += travel
}
}
query += where
if searchQuery.Columns == repository.ColumnsEvent {
@@ -85,6 +93,14 @@ func prepareColumns(criteria querier, columns repository.Columns) (string, func(
}
}
func prepareTimeTravel(ctx context.Context, criteria querier, allow bool) string {
if !allow {
return ""
}
took := call.Took(ctx)
return criteria.Timetravel(took)
}
func maxSequenceScanner(row scan, dest interface{}) (err error) {
sequence, ok := dest.(*Sequence)
if !ok {

View File

@@ -10,6 +10,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
@@ -537,7 +538,10 @@ func Test_query_events_with_crdb(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{
client: tt.fields.client,
DB: &database.DB{
DB: tt.fields.client,
Database: new(testDB),
},
}
// setup initial data for query
@@ -657,6 +661,36 @@ func Test_query_events_mocked(t *testing.T) {
wantErr: false,
},
},
{
name: "with limit and order by desc as of system time",
args: args{
dest: &[]*repository.Event{},
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: true,
Limit: 5,
AllowTimeTravel: true,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
},
},
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)},
),
},
res: res{
wantErr: false,
},
},
{
name: "error sql conn closed",
args: args{
@@ -786,9 +820,11 @@ func Test_query_events_mocked(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
crdb := &CRDB{}
crdb := &CRDB{DB: &database.DB{
Database: new(testDB),
}}
if tt.fields.mock != nil {
crdb.client = tt.fields.mock.client
crdb.DB.DB = tt.fields.mock.client
}
err := query(context.Background(), crdb, tt.args.query, tt.args.dest)

View File

@@ -12,14 +12,15 @@ import (
// SearchQueryBuilder represents the builder for your filter
// if invalid data are set the filter will fail
type SearchQueryBuilder struct {
columns repository.Columns
limit uint64
desc bool
resourceOwner string
instanceID string
editorUser string
queries []*SearchQuery
tx *sql.Tx
columns repository.Columns
limit uint64
desc bool
resourceOwner string
instanceID string
editorUser string
queries []*SearchQuery
tx *sql.Tx
allowTimeTravel bool
}
type SearchQuery struct {
@@ -130,6 +131,13 @@ func (builder *SearchQueryBuilder) EditorUser(id string) *SearchQueryBuilder {
return builder
}
// AllowTimeTravel activates the time travel feature of the database if supported
// The queries will be made based on the call time
func (builder *SearchQueryBuilder) AllowTimeTravel() *SearchQueryBuilder {
builder.allowTimeTravel = true
return builder
}
// AddQuery creates a new sub query.
// All fields in the sub query are AND-connected in the storage request.
// Multiple sub queries are OR-connected in the storage request.
@@ -264,11 +272,12 @@ func (builder *SearchQueryBuilder) build(instanceID string) (*repository.SearchQ
}
return &repository.SearchQuery{
Columns: builder.columns,
Limit: builder.limit,
Desc: builder.desc,
Filters: filters,
Tx: builder.tx,
Columns: builder.columns,
Limit: builder.limit,
Desc: builder.desc,
Filters: filters,
Tx: builder.tx,
AllowTimeTravel: builder.allowTimeTravel,
}, nil
}

View File

@@ -2,8 +2,8 @@ package v1
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore/v1/internal/repository"
z_sql "github.com/zitadel/zitadel/internal/eventstore/v1/internal/repository/sql"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
@@ -22,7 +22,7 @@ type eventstore struct {
repo repository.Repository
}
func Start(db *sql.DB) (Eventstore, error) {
func Start(db *database.DB) (Eventstore, error) {
return &eventstore{
repo: z_sql.Start(db),
}, nil

View File

@@ -1,10 +1,10 @@
package sql
import (
"database/sql"
"github.com/zitadel/zitadel/internal/database"
)
func Start(client *sql.DB) *SQL {
func Start(client *database.DB) *SQL {
return &SQL{
client: client,
}

View File

@@ -12,7 +12,7 @@ import (
)
const (
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE \( aggregate_type = \$1`
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1`
)
var (
@@ -172,14 +172,14 @@ func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock {
}
func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock {
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`).
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType).
WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence))
return db
}
func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock {
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`).
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType).WillReturnError(err)
return db
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
@@ -16,16 +17,16 @@ type Querier interface {
}
func (db *SQL) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
return filter(db.client, searchQuery)
return filter(ctx, db.client, searchQuery)
}
func filter(querier Querier, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
query, limit, values, rowScanner := buildQuery(searchQuery)
func filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
query, limit, values, rowScanner := buildQuery(ctx, db, searchQuery)
if query == "" {
return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
}
rows, err := querier.Query(query, values...)
rows, err := db.Query(query, values...)
if err != nil {
logging.New().WithError(err).Info("query failed")
return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
@@ -48,7 +49,7 @@ func filter(querier Querier, searchQuery *es_models.SearchQueryFactory) (events
}
func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
query, _, values, rowScanner := buildQuery(queryFactory)
query, _, values, rowScanner := buildQuery(ctx, db.client, queryFactory)
if query == "" {
return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
}
@@ -63,7 +64,7 @@ func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.Searc
}
func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) ([]string, error) {
query, _, values, rowScanner := buildQuery(queryFactory)
query, _, values, rowScanner := buildQuery(ctx, db.client, queryFactory)
if query == "" {
return nil, errors.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory")
}

View File

@@ -6,6 +6,7 @@ import (
"math"
"testing"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
)
@@ -122,7 +123,7 @@ func TestSQL_Filter(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sql := &SQL{
client: tt.fields.client.sqlClient,
client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)},
}
events, err := sql.Filter(context.Background(), tt.args.searchQuery)
if (err != nil) != tt.res.wantErr {
@@ -217,7 +218,7 @@ func TestSQL_LatestSequence(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sql := &SQL{
client: tt.fields.client.sqlClient,
client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)},
}
sequence, err := sql.LatestSequence(context.Background(), tt.args.searchQuery)
if (err != nil) != tt.res.wantErr {

View File

@@ -1,6 +1,7 @@
package sql
import (
"context"
"database/sql"
"errors"
"fmt"
@@ -9,6 +10,8 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database/dialect"
z_errors "github.com/zitadel/zitadel/internal/errors"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
)
@@ -30,7 +33,7 @@ const (
" FROM eventstore.events"
)
func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scan, dest interface{}) error) {
func buildQuery(ctx context.Context, db dialect.Database, queryFactory *es_models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scan, dest interface{}) error) {
searchQuery, err := queryFactory.Build()
if err != nil {
logging.New().WithError(err).Warn("search query factory invalid")
@@ -41,6 +44,10 @@ func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit
if where == "" || query == "" {
return "", 0, nil, nil
}
if travel := db.Timetravel(call.Took(ctx)); travel != "" {
query += travel
}
query += where
if searchQuery.Columns == es_models.Columns_Event {

View File

@@ -1,6 +1,7 @@
package sql
import (
"context"
"database/sql"
"reflect"
"testing"
@@ -435,7 +436,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user")},
},
@@ -446,7 +447,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE ( aggregate_type = $1 ) ORDER BY event_sequence LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,
@@ -458,7 +459,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory().Limit(5).OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,
@@ -466,8 +467,10 @@ func Test_buildQuery(t *testing.T) {
},
}
for _, tt := range tests {
ctx := context.Background()
db := new(testDB)
t.Run(tt.name, func(t *testing.T) {
gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(tt.args.queryFactory)
gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(ctx, db, tt.args.queryFactory)
if gotQuery != tt.res.query {
t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query)
}
@@ -489,3 +492,13 @@ func Test_buildQuery(t *testing.T) {
})
}
}
type testDB struct{}
func (_ *testDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }
func (*testDB) DatabaseName() string { return "db" }
func (*testDB) Username() string { return "user" }
func (*testDB) Type() string { return "type" }

View File

@@ -2,11 +2,12 @@ package sql
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/internal/database"
)
type SQL struct {
client *sql.DB
client *database.DB
}
func (db *SQL) Health(ctx context.Context) error {

View File

@@ -2,7 +2,6 @@ package access
import (
"context"
"database/sql"
"fmt"
"net/http"
"strings"
@@ -12,7 +11,9 @@ import (
"github.com/zitadel/logging"
"google.golang.org/grpc/codes"
"github.com/zitadel/zitadel/internal/api/call"
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/database"
caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/repository/quota"
@@ -36,10 +37,10 @@ var _ logstore.UsageQuerier = (*databaseLogStorage)(nil)
var _ logstore.LogCleanupper = (*databaseLogStorage)(nil)
type databaseLogStorage struct {
dbClient *sql.DB
dbClient *database.DB
}
func NewDatabaseLogStorage(dbClient *sql.DB) *databaseLogStorage {
func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage {
return &databaseLogStorage{dbClient: dbClient}
}
@@ -98,12 +99,11 @@ func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord
return nil
}
// TODO: AS OF SYSTEM TIME
func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) {
stmt, args, err := squirrel.Select(
fmt.Sprintf("count(%s)", accessInstanceIdCol),
).
From(accessLogsTable).
From(accessLogsTable + l.dbClient.Timetravel(call.Took(ctx))).
Where(squirrel.And{
squirrel.Eq{accessInstanceIdCol: instanceId},
squirrel.GtOrEq{accessTimestampCol: start},

View File

@@ -2,13 +2,14 @@ package execution
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/Masterminds/squirrel"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/repository/quota"
@@ -29,10 +30,10 @@ var _ logstore.UsageQuerier = (*databaseLogStorage)(nil)
var _ logstore.LogCleanupper = (*databaseLogStorage)(nil)
type databaseLogStorage struct {
dbClient *sql.DB
dbClient *database.DB
}
func NewDatabaseLogStorage(dbClient *sql.DB) *databaseLogStorage {
func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage {
return &databaseLogStorage{dbClient: dbClient}
}
@@ -91,12 +92,11 @@ func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord
return nil
}
// TODO: AS OF SYSTEM TIME
func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) {
stmt, args, err := squirrel.Select(
fmt.Sprintf("COALESCE(SUM(%s)::INT,0)", executionTookCol),
).
From(executionLogsTable).
From(executionLogsTable + l.dbClient.Timetravel(call.Took(ctx))).
Where(squirrel.And{
squirrel.Eq{executionInstanceIdCol: instanceId},
squirrel.GtOrEq{executionTimestampCol: start},

View File

@@ -7,9 +7,12 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/repository/quota"
)
const handleThresholdTimeout = time.Minute
type QuotaQuerier interface {
GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (config *quota.AddedEvent, periodStart time.Time, err error)
GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, used uint64) ([]*quota.NotifiedEvent, error)
@@ -94,17 +97,29 @@ func (s *Service) Limit(ctx context.Context, instanceID string) *uint64 {
return nil
}
go s.handleThresholds(ctx, quota, periodStart, usage)
var remaining *uint64
if quota.Limit {
r := uint64(math.Max(0, float64(quota.Amount)-float64(usage)))
remaining = &r
}
notifications, err := s.quotaQuerier.GetDueQuotaNotifications(ctx, quota, periodStart, usage)
if err != nil {
return remaining
}
err = s.usageReporter.Report(ctx, notifications)
return remaining
}
func (s *Service) handleThresholds(ctx context.Context, quota *quota.AddedEvent, periodStart time.Time, usage uint64) {
var err error
defer func() {
logging.OnError(err).Warn("handling quota thresholds failed")
}()
detatchedCtx, cancel := context.WithTimeout(authz.Detach(ctx), handleThresholdTimeout)
defer cancel()
notifications, err := s.quotaQuerier.GetDueQuotaNotifications(detatchedCtx, quota, periodStart, usage)
if err != nil || len(notifications) == 0 {
return
}
err = s.usageReporter.Report(detatchedCtx, notifications)
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -117,7 +118,7 @@ func (q *Queries) SearchActions(ctx context.Context, queries *ActionSearchQuerie
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareActionsQuery()
query, scan := prepareActionsQuery(ctx, q.client)
eq := sq.Eq{
ActionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}
@@ -145,7 +146,7 @@ func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, wi
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareActionQuery()
stmt, scan := prepareActionQuery(ctx, q.client)
eq := sq.Eq{
ActionColumnID.identifier(): id,
ActionColumnResourceOwner.identifier(): orgID,
@@ -179,7 +180,7 @@ func NewActionIDSearchQuery(id string) (SearchQuery, error) {
return NewTextQuery(ActionColumnID, id, TextEquals)
}
func prepareActionsQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*Actions, error)) {
func prepareActionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*Actions, error)) {
return sq.Select(
ActionColumnID.identifier(),
ActionColumnCreationDate.identifier(),
@@ -192,7 +193,8 @@ func prepareActionsQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*Actions, er
ActionColumnTimeout.identifier(),
ActionColumnAllowedToFail.identifier(),
countColumn.identifier(),
).From(actionTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(actionTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*Actions, error) {
actions := make([]*Action, 0)
var count uint64
@@ -230,7 +232,7 @@ func prepareActionsQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*Actions, er
}
}
func prepareActionQuery() (sq.SelectBuilder, func(row *sql.Row) (*Action, error)) {
func prepareActionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) (*Action, error)) {
return sq.Select(
ActionColumnID.identifier(),
ActionColumnCreationDate.identifier(),
@@ -242,7 +244,8 @@ func prepareActionQuery() (sq.SelectBuilder, func(row *sql.Row) (*Action, error)
ActionColumnScript.identifier(),
ActionColumnTimeout.identifier(),
ActionColumnAllowedToFail.identifier(),
).From(actionTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(actionTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*Action, error) {
action := new(Action)
err := row.Scan(

View File

@@ -8,6 +8,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -70,7 +71,7 @@ func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID s
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareFlowQuery(flowType)
query, scan := prepareFlowQuery(ctx, q.client, flowType)
eq := sq.Eq{
FlowsTriggersColumnFlowType.identifier(): flowType,
FlowsTriggersColumnResourceOwner.identifier(): orgID,
@@ -95,7 +96,7 @@ func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flow
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareTriggerActionsQuery()
stmt, scan := prepareTriggerActionsQuery(ctx, q.client)
eq := sq.Eq{
FlowsTriggersColumnFlowType.identifier(): flowType,
FlowsTriggersColumnTriggerType.identifier(): triggerType,
@@ -122,7 +123,7 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, w
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareFlowTypesQuery()
stmt, scan := prepareFlowTypesQuery(ctx, q.client)
eq := sq.Eq{
FlowsTriggersColumnActionID.identifier(): actionID,
FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -143,11 +144,11 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, w
return scan(rows)
}
func prepareFlowTypesQuery() (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowType, error)) {
func prepareFlowTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowType, error)) {
return sq.Select(
FlowsTriggersColumnFlowType.identifier(),
).
From(flowsTriggersTable.identifier()).
From(flowsTriggersTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) ([]domain.FlowType, error) {
types := []domain.FlowType{}
@@ -166,7 +167,7 @@ func prepareFlowTypesQuery() (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowTy
}
func prepareTriggerActionsQuery() (sq.SelectBuilder, func(*sql.Rows) ([]*Action, error)) {
func prepareTriggerActionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]*Action, error)) {
return sq.Select(
ActionColumnID.identifier(),
ActionColumnCreationDate.identifier(),
@@ -180,7 +181,7 @@ func prepareTriggerActionsQuery() (sq.SelectBuilder, func(*sql.Rows) ([]*Action,
ActionColumnTimeout.identifier(),
).
From(flowsTriggersTable.name).
LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID)).
LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) ([]*Action, error) {
actions := make([]*Action, 0)
@@ -212,7 +213,7 @@ func prepareTriggerActionsQuery() (sq.SelectBuilder, func(*sql.Rows) ([]*Action,
}
}
func prepareFlowQuery(flowType domain.FlowType) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
func prepareFlowQuery(ctx context.Context, db prepareDatabase, flowType domain.FlowType) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return sq.Select(
ActionColumnID.identifier(),
ActionColumnCreationDate.identifier(),
@@ -232,7 +233,7 @@ func prepareFlowQuery(flowType domain.FlowType) (sq.SelectBuilder, func(*sql.Row
FlowsTriggersColumnResourceOwner.identifier(),
).
From(flowsTriggersTable.name).
LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID)).
LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*Flow, error) {
flow := &Flow{

View File

@@ -1,6 +1,7 @@
package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@@ -14,6 +15,82 @@ import (
"github.com/zitadel/zitadel/internal/domain"
)
var (
prepareFlowStmt = `SELECT projections.actions3.id,` +
` projections.actions3.creation_date,` +
` projections.actions3.change_date,` +
` projections.actions3.resource_owner,` +
` projections.actions3.action_state,` +
` projections.actions3.sequence,` +
` projections.actions3.name,` +
` projections.actions3.script,` +
` projections.actions3.allowed_to_fail,` +
` projections.actions3.timeout,` +
` projections.flow_triggers2.trigger_type,` +
` projections.flow_triggers2.trigger_sequence,` +
` projections.flow_triggers2.flow_type,` +
` projections.flow_triggers2.change_date,` +
` projections.flow_triggers2.sequence,` +
` projections.flow_triggers2.resource_owner` +
` FROM projections.flow_triggers2` +
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`
prepareFlowCols = []string{
"id",
"creation_date",
"change_date",
"resource_owner",
"state",
"sequence",
"name",
"script",
"allowed_to_fail",
"timeout",
// flow
"trigger_type",
"trigger_sequence",
"flow_type",
"change_date",
"sequence",
"resource_owner",
}
prepareTriggerActionStmt = `SELECT projections.actions3.id,` +
` projections.actions3.creation_date,` +
` projections.actions3.change_date,` +
` projections.actions3.resource_owner,` +
` projections.actions3.action_state,` +
` projections.actions3.sequence,` +
` projections.actions3.name,` +
` projections.actions3.script,` +
` projections.actions3.allowed_to_fail,` +
` projections.actions3.timeout` +
` FROM projections.flow_triggers2` +
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`
prepareTriggerActionCols = []string{
"id",
"creation_date",
"change_date",
"resource_owner",
"state",
"sequence",
"name",
"script",
"allowed_to_fail",
"timeout",
}
prepareFlowTypeStmt = `SELECT projections.flow_triggers2.flow_type` +
` FROM projections.flow_triggers2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareFlowTypeCols = []string{
"flow_type",
}
)
func Test_FlowPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -27,29 +104,12 @@ func Test_FlowPrepares(t *testing.T) {
}{
{
name: "prepareFlowQuery no result",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(domain.FlowTypeExternalAuthentication)
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication)
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout,`+
` projections.flow_triggers2.trigger_type,`+
` projections.flow_triggers2.trigger_sequence,`+
` projections.flow_triggers2.flow_type,`+
` projections.flow_triggers2.change_date,`+
` projections.flow_triggers2.sequence,`+
` projections.flow_triggers2.resource_owner`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
regexp.QuoteMeta(prepareFlowStmt),
nil,
nil,
),
@@ -61,48 +121,13 @@ func Test_FlowPrepares(t *testing.T) {
},
{
name: "prepareFlowQuery one action",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(domain.FlowTypeExternalAuthentication)
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication)
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout,`+
` projections.flow_triggers2.trigger_type,`+
` projections.flow_triggers2.trigger_sequence,`+
` projections.flow_triggers2.flow_type,`+
` projections.flow_triggers2.change_date,`+
` projections.flow_triggers2.sequence,`+
` projections.flow_triggers2.resource_owner`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"state",
"sequence",
"name",
"script",
"allowed_to_fail",
"timeout",
//flow
"trigger_type",
"trigger_sequence",
"flow_type",
"change_date",
"sequence",
"resource_owner",
},
regexp.QuoteMeta(prepareFlowStmt),
prepareFlowCols,
[][]driver.Value{
{
"action-id",
@@ -150,48 +175,13 @@ func Test_FlowPrepares(t *testing.T) {
},
{
name: "prepareFlowQuery multiple actions",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(domain.FlowTypeExternalAuthentication)
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication)
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout,`+
` projections.flow_triggers2.trigger_type,`+
` projections.flow_triggers2.trigger_sequence,`+
` projections.flow_triggers2.flow_type,`+
` projections.flow_triggers2.change_date,`+
` projections.flow_triggers2.sequence,`+
` projections.flow_triggers2.resource_owner`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"state",
"sequence",
"name",
"script",
"allowed_to_fail",
"timeout",
//flow
"trigger_type",
"trigger_sequence",
"flow_type",
"change_date",
"sequence",
"resource_owner",
},
regexp.QuoteMeta(prepareFlowStmt),
prepareFlowCols,
[][]driver.Value{
{
"action-id-pre",
@@ -271,48 +261,13 @@ func Test_FlowPrepares(t *testing.T) {
},
{
name: "prepareFlowQuery no action",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(domain.FlowTypeExternalAuthentication)
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication)
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout,`+
` projections.flow_triggers2.trigger_type,`+
` projections.flow_triggers2.trigger_sequence,`+
` projections.flow_triggers2.flow_type,`+
` projections.flow_triggers2.change_date,`+
` projections.flow_triggers2.sequence,`+
` projections.flow_triggers2.resource_owner`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"state",
"sequence",
"name",
"script",
"allowed_to_fail",
"timeout",
//flow
"trigger_type",
"trigger_sequence",
"flow_type",
"change_date",
"sequence",
"resource_owner",
},
regexp.QuoteMeta(prepareFlowStmt),
prepareFlowCols,
[][]driver.Value{
{
nil,
@@ -345,29 +300,12 @@ func Test_FlowPrepares(t *testing.T) {
},
{
name: "prepareFlowQuery sql err",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(domain.FlowTypeExternalAuthentication)
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication)
},
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout,`+
` projections.flow_triggers2.trigger_type,`+
` projections.flow_triggers2.trigger_sequence,`+
` projections.flow_triggers2.flow_type,`+
` projections.flow_triggers2.change_date,`+
` projections.flow_triggers2.sequence,`+
` projections.flow_triggers2.resource_owner`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
regexp.QuoteMeta(prepareFlowStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -384,18 +322,7 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareTriggerActionsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
regexp.QuoteMeta(prepareTriggerActionStmt),
nil,
nil,
),
@@ -407,30 +334,8 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareTriggerActionsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"state",
"sequence",
"name",
"script",
"allowed_to_fail",
"timeout",
},
regexp.QuoteMeta(prepareTriggerActionStmt),
prepareTriggerActionCols,
[][]driver.Value{
{
"action-id",
@@ -467,30 +372,8 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareTriggerActionsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"state",
"sequence",
"name",
"script",
"allowed_to_fail",
"timeout",
},
regexp.QuoteMeta(prepareTriggerActionStmt),
prepareTriggerActionCols,
[][]driver.Value{
{
"action-id-1",
@@ -551,18 +434,7 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareTriggerActionsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.action_state,`+
` projections.actions3.sequence,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.allowed_to_fail,`+
` projections.actions3.timeout`+
` FROM projections.flow_triggers2`+
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id`),
regexp.QuoteMeta(prepareTriggerActionStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -579,8 +451,7 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareFlowTypesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.flow_triggers2.flow_type`+
` FROM projections.flow_triggers2`),
regexp.QuoteMeta(prepareFlowTypeStmt),
nil,
nil,
),
@@ -592,11 +463,8 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareFlowTypesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.flow_triggers2.flow_type`+
` FROM projections.flow_triggers2`),
[]string{
"flow_type",
},
regexp.QuoteMeta(prepareFlowTypeStmt),
prepareFlowTypeCols,
[][]driver.Value{
{
domain.FlowTypeExternalAuthentication,
@@ -613,11 +481,8 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareFlowTypesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.flow_triggers2.flow_type`+
` FROM projections.flow_triggers2`),
[]string{
"flow_type",
},
regexp.QuoteMeta(prepareFlowTypeStmt),
prepareFlowTypeCols,
[][]driver.Value{
{
domain.FlowTypeExternalAuthentication,
@@ -638,8 +503,7 @@ func Test_FlowPrepares(t *testing.T) {
prepare: prepareFlowTypesQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.flow_triggers2.flow_type`+
` FROM projections.flow_triggers2`),
regexp.QuoteMeta(prepareFlowTypeStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -654,7 +518,7 @@ func Test_FlowPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -13,6 +13,60 @@ import (
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
prepareActionsStmt = `SELECT projections.actions3.id,` +
` projections.actions3.creation_date,` +
` projections.actions3.change_date,` +
` projections.actions3.resource_owner,` +
` projections.actions3.sequence,` +
` projections.actions3.action_state,` +
` projections.actions3.name,` +
` projections.actions3.script,` +
` projections.actions3.timeout,` +
` projections.actions3.allowed_to_fail,` +
` COUNT(*) OVER ()` +
` FROM projections.actions3` +
` AS OF SYSTEM TIME '-1 ms'`
prepareActionsCols = []string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"action_state",
"name",
"script",
"timeout",
"allowed_to_fail",
"count",
}
prepareActionStmt = `SELECT projections.actions3.id,` +
` projections.actions3.creation_date,` +
` projections.actions3.change_date,` +
` projections.actions3.resource_owner,` +
` projections.actions3.sequence,` +
` projections.actions3.action_state,` +
` projections.actions3.name,` +
` projections.actions3.script,` +
` projections.actions3.timeout,` +
` projections.actions3.allowed_to_fail` +
` FROM projections.actions3` +
` AS OF SYSTEM TIME '-1 ms'`
prepareActionCols = []string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"action_state",
"name",
"script",
"timeout",
"allowed_to_fail",
}
)
func Test_ActionPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -29,18 +83,7 @@ func Test_ActionPrepares(t *testing.T) {
prepare: prepareActionsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.sequence,`+
` projections.actions3.action_state,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.timeout,`+
` projections.actions3.allowed_to_fail,`+
` COUNT(*) OVER ()`+
` FROM projections.actions3`),
regexp.QuoteMeta(prepareActionsStmt),
nil,
nil,
),
@@ -52,31 +95,8 @@ func Test_ActionPrepares(t *testing.T) {
prepare: prepareActionsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.sequence,`+
` projections.actions3.action_state,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.timeout,`+
` projections.actions3.allowed_to_fail,`+
` COUNT(*) OVER ()`+
` FROM projections.actions3`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"action_state",
"name",
"script",
"timeout",
"allowed_to_fail",
"count",
},
regexp.QuoteMeta(prepareActionsStmt),
prepareActionsCols,
[][]driver.Value{
{
"id",
@@ -118,31 +138,8 @@ func Test_ActionPrepares(t *testing.T) {
prepare: prepareActionsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.sequence,`+
` projections.actions3.action_state,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.timeout,`+
` projections.actions3.allowed_to_fail,`+
` COUNT(*) OVER ()`+
` FROM projections.actions3`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"action_state",
"name",
"script",
"timeout",
"allowed_to_fail",
"count",
},
regexp.QuoteMeta(prepareActionsStmt),
prepareActionsCols,
[][]driver.Value{
{
"id-1",
@@ -208,18 +205,7 @@ func Test_ActionPrepares(t *testing.T) {
prepare: prepareActionsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.sequence,`+
` projections.actions3.action_state,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.timeout,`+
` projections.actions3.allowed_to_fail,`+
` COUNT(*) OVER ()`+
` FROM projections.actions3`),
regexp.QuoteMeta(prepareActionsStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -236,17 +222,7 @@ func Test_ActionPrepares(t *testing.T) {
prepare: prepareActionQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.sequence,`+
` projections.actions3.action_state,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.timeout,`+
` projections.actions3.allowed_to_fail`+
` FROM projections.actions3`),
regexp.QuoteMeta(prepareActionStmt),
nil,
nil,
),
@@ -264,29 +240,8 @@ func Test_ActionPrepares(t *testing.T) {
prepare: prepareActionQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.sequence,`+
` projections.actions3.action_state,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.timeout,`+
` projections.actions3.allowed_to_fail`+
` FROM projections.actions3`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"action_state",
"name",
"script",
"timeout",
"allowed_to_fail",
},
regexp.QuoteMeta(prepareActionStmt),
prepareActionCols,
[]driver.Value{
"id",
testNow,
@@ -319,17 +274,7 @@ func Test_ActionPrepares(t *testing.T) {
prepare: prepareActionQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.actions3.id,`+
` projections.actions3.creation_date,`+
` projections.actions3.change_date,`+
` projections.actions3.resource_owner,`+
` projections.actions3.sequence,`+
` projections.actions3.action_state,`+
` projections.actions3.name,`+
` projections.actions3.script,`+
` projections.actions3.timeout,`+
` projections.actions3.allowed_to_fail`+
` FROM projections.actions3`),
regexp.QuoteMeta(prepareActionStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -344,7 +289,7 @@ func Test_ActionPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
@@ -250,7 +251,7 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo
projection.AppProjection.Trigger(ctx)
}
stmt, scan := prepareAppQuery()
stmt, scan := prepareAppQuery(ctx, q.client)
eq := sq.Eq{
AppColumnID.identifier(): appID,
AppColumnProjectID.identifier(): projectID,
@@ -272,7 +273,7 @@ func (q *Queries) AppByID(ctx context.Context, appID string, withOwnerRemoved bo
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareAppQuery()
stmt, scan := prepareAppQuery(ctx, q.client)
eq := sq.Eq{
AppColumnID.identifier(): appID,
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -293,7 +294,7 @@ func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string, withOw
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareAppQuery()
stmt, scan := prepareAppQuery(ctx, q.client)
eq := sq.Eq{
AppSAMLConfigColumnEntityID.identifier(): entityID,
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -314,7 +315,7 @@ func (q *Queries) ProjectByClientID(ctx context.Context, appID string, withOwner
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareProjectByAppQuery()
stmt, scan := prepareProjectByAppQuery(ctx, q.client)
eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[ProjectColumnOwnerRemoved.identifier()] = false
@@ -339,7 +340,7 @@ func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string, w
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareProjectIDByAppQuery()
stmt, scan := prepareProjectIDByAppQuery(ctx, q.client)
eq := sq.Eq{
AppOIDCConfigColumnClientID.identifier(): appID,
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -360,7 +361,7 @@ func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string, withO
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareProjectIDByAppQuery()
stmt, scan := prepareProjectIDByAppQuery(ctx, q.client)
eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[AppColumnOwnerRemoved.identifier()] = false
@@ -386,7 +387,7 @@ func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string, withOwne
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareProjectByAppQuery()
stmt, scan := prepareProjectByAppQuery(ctx, q.client)
eq := sq.Eq{
AppOIDCConfigColumnClientID.identifier(): id,
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -407,7 +408,7 @@ func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string, withOw
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareAppQuery()
stmt, scan := prepareAppQuery(ctx, q.client)
eq := sq.Eq{
AppOIDCConfigColumnClientID.identifier(): clientID,
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -428,7 +429,7 @@ func (q *Queries) AppByClientID(ctx context.Context, clientID string, withOwnerR
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareAppQuery()
stmt, scan := prepareAppQuery(ctx, q.client)
eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[AppColumnOwnerRemoved.identifier()] = false
@@ -452,7 +453,7 @@ func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, wit
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareAppsQuery()
query, scan := prepareAppsQuery(ctx, q.client)
eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[AppColumnOwnerRemoved.identifier()] = false
@@ -478,7 +479,7 @@ func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareClientIDsQuery()
query, scan := prepareClientIDsQuery(ctx, q.client)
eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[AppColumnOwnerRemoved.identifier()] = false
@@ -503,7 +504,7 @@ func NewAppProjectIDSearchQuery(id string) (SearchQuery, error) {
return NewTextQuery(AppColumnProjectID, id, TextEquals)
}
func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
func prepareAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
return sq.Select(
AppColumnID.identifier(),
AppColumnName.identifier(),
@@ -542,7 +543,7 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) {
app := new(App)
@@ -604,13 +605,13 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
}
}
func prepareProjectIDByAppQuery() (sq.SelectBuilder, func(*sql.Row) (projectID string, err error)) {
func prepareProjectIDByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (projectID string, err error)) {
return sq.Select(
AppColumnProjectID.identifier(),
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (projectID string, err error) {
err = row.Scan(
&projectID,
@@ -627,7 +628,7 @@ func prepareProjectIDByAppQuery() (sq.SelectBuilder, func(*sql.Row) (projectID s
}
}
func prepareProjectByAppQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, error)) {
func prepareProjectByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Project, error)) {
return sq.Select(
ProjectColumnID.identifier(),
ProjectColumnCreationDate.identifier(),
@@ -644,7 +645,7 @@ func prepareProjectByAppQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, err
Join(join(AppColumnProjectID, ProjectColumnID)).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*Project, error) {
p := new(Project)
@@ -671,7 +672,7 @@ func prepareProjectByAppQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, err
}
}
func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
return sq.Select(
AppColumnID.identifier(),
AppColumnName.identifier(),
@@ -711,7 +712,7 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Rows) (*Apps, error) {
apps := &Apps{Apps: []*App{}}
@@ -777,13 +778,13 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
}
}
func prepareClientIDsQuery() (sq.SelectBuilder, func(*sql.Rows) ([]string, error)) {
func prepareClientIDsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]string, error)) {
return sq.Select(
AppAPIConfigColumnClientID.identifier(),
AppOIDCConfigColumnClientID.identifier(),
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) ([]string, error) {
ids := database.StringArray{}

View File

@@ -52,7 +52,8 @@ var (
` FROM projections.apps4` +
` LEFT JOIN projections.apps4_api_configs ON projections.apps4.id = projections.apps4_api_configs.app_id AND projections.apps4.instance_id = projections.apps4_api_configs.instance_id` +
` LEFT JOIN projections.apps4_oidc_configs ON projections.apps4.id = projections.apps4_oidc_configs.app_id AND projections.apps4.instance_id = projections.apps4_oidc_configs.instance_id` +
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id`)
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedAppsQuery = regexp.QuoteMeta(`SELECT projections.apps4.id,` +
` projections.apps4.name,` +
` projections.apps4.project_id,` +
@@ -91,17 +92,20 @@ var (
` FROM projections.apps4` +
` LEFT JOIN projections.apps4_api_configs ON projections.apps4.id = projections.apps4_api_configs.app_id AND projections.apps4.instance_id = projections.apps4_api_configs.instance_id` +
` LEFT JOIN projections.apps4_oidc_configs ON projections.apps4.id = projections.apps4_oidc_configs.app_id AND projections.apps4.instance_id = projections.apps4_oidc_configs.instance_id` +
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id`)
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedAppIDsQuery = regexp.QuoteMeta(`SELECT projections.apps4_api_configs.client_id,` +
` projections.apps4_oidc_configs.client_id` +
` FROM projections.apps4` +
` LEFT JOIN projections.apps4_api_configs ON projections.apps4.id = projections.apps4_api_configs.app_id AND projections.apps4.instance_id = projections.apps4_api_configs.instance_id` +
` LEFT JOIN projections.apps4_oidc_configs ON projections.apps4.id = projections.apps4_oidc_configs.app_id AND projections.apps4.instance_id = projections.apps4_oidc_configs.instance_id`)
` LEFT JOIN projections.apps4_oidc_configs ON projections.apps4.id = projections.apps4_oidc_configs.app_id AND projections.apps4.instance_id = projections.apps4_oidc_configs.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedProjectIDByAppQuery = regexp.QuoteMeta(`SELECT projections.apps4.project_id` +
` FROM projections.apps4` +
` LEFT JOIN projections.apps4_api_configs ON projections.apps4.id = projections.apps4_api_configs.app_id AND projections.apps4.instance_id = projections.apps4_api_configs.instance_id` +
` LEFT JOIN projections.apps4_oidc_configs ON projections.apps4.id = projections.apps4_oidc_configs.app_id AND projections.apps4.instance_id = projections.apps4_oidc_configs.instance_id` +
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id`)
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedProjectByAppQuery = regexp.QuoteMeta(`SELECT projections.projects3.id,` +
` projections.projects3.creation_date,` +
` projections.projects3.change_date,` +
@@ -117,7 +121,8 @@ var (
` JOIN projections.apps4 ON projections.projects3.id = projections.apps4.project_id AND projections.projects3.instance_id = projections.apps4.instance_id` +
` LEFT JOIN projections.apps4_api_configs ON projections.apps4.id = projections.apps4_api_configs.app_id AND projections.apps4.instance_id = projections.apps4_api_configs.instance_id` +
` LEFT JOIN projections.apps4_oidc_configs ON projections.apps4.id = projections.apps4_oidc_configs.app_id AND projections.apps4.instance_id = projections.apps4_oidc_configs.instance_id` +
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id`)
` LEFT JOIN projections.apps4_saml_configs ON projections.apps4.id = projections.apps4_saml_configs.app_id AND projections.apps4.instance_id = projections.apps4_saml_configs.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
appCols = database.StringArray{
"id",
@@ -1009,7 +1014,7 @@ func Test_AppsPrepare(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
@@ -1628,7 +1633,7 @@ func Test_AppPrepare(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
@@ -1714,7 +1719,7 @@ func Test_AppIDsPrepare(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
@@ -1780,7 +1785,7 @@ func Test_ProjectIDByAppPrepare(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
@@ -1978,7 +1983,7 @@ func Test_ProjectByAppPrepare(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -129,7 +130,7 @@ func (q *Queries) SearchAuthNKeys(ctx context.Context, queries *AuthNKeySearchQu
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareAuthNKeysQuery()
query, scan := prepareAuthNKeysQuery(ctx, q.client)
query = queries.toQuery(query)
eq := sq.Eq{
AuthNKeyColumnEnabled.identifier(): true,
@@ -159,7 +160,7 @@ func (q *Queries) SearchAuthNKeysData(ctx context.Context, queries *AuthNKeySear
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareAuthNKeysDataQuery()
query, scan := prepareAuthNKeysDataQuery(ctx, q.client)
query = queries.toQuery(query)
eq := sq.Eq{
AuthNKeyColumnEnabled.identifier(): true,
@@ -193,7 +194,7 @@ func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, i
projection.AuthNKeyProjection.Trigger(ctx)
}
query, scan := prepareAuthNKeyQuery()
query, scan := prepareAuthNKeyQuery(ctx, q.client)
for _, q := range queries {
query = q.toQuery(query)
}
@@ -218,7 +219,7 @@ func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareAuthNKeyPublicKeyQuery()
stmt, scan := prepareAuthNKeyPublicKeyQuery(ctx, q.client)
eq := sq.And{
sq.Eq{
AuthNKeyColumnID.identifier(): id,
@@ -265,7 +266,7 @@ func NewAuthNKeyObjectIDQuery(id string) (SearchQuery, error) {
return NewTextQuery(AuthNKeyColumnObjectID, id, TextEquals)
}
func prepareAuthNKeysQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeys, error)) {
func prepareAuthNKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeys, error)) {
return sq.Select(
AuthNKeyColumnID.identifier(),
AuthNKeyColumnCreationDate.identifier(),
@@ -275,7 +276,8 @@ func prepareAuthNKeysQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeys
AuthNKeyColumnExpiration.identifier(),
AuthNKeyColumnType.identifier(),
countColumn.identifier(),
).From(authNKeyTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*AuthNKeys, error) {
authNKeys := make([]*AuthNKey, 0)
var count uint64
@@ -310,7 +312,7 @@ func prepareAuthNKeysQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeys
}
}
func prepareAuthNKeyQuery() (sq.SelectBuilder, func(row *sql.Row) (*AuthNKey, error)) {
func prepareAuthNKeyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) (*AuthNKey, error)) {
return sq.Select(
AuthNKeyColumnID.identifier(),
AuthNKeyColumnCreationDate.identifier(),
@@ -319,7 +321,8 @@ func prepareAuthNKeyQuery() (sq.SelectBuilder, func(row *sql.Row) (*AuthNKey, er
AuthNKeyColumnSequence.identifier(),
AuthNKeyColumnExpiration.identifier(),
AuthNKeyColumnType.identifier(),
).From(authNKeyTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*AuthNKey, error) {
authNKey := new(AuthNKey)
err := row.Scan(
@@ -341,10 +344,11 @@ func prepareAuthNKeyQuery() (sq.SelectBuilder, func(row *sql.Row) (*AuthNKey, er
}
}
func prepareAuthNKeyPublicKeyQuery() (sq.SelectBuilder, func(row *sql.Row) ([]byte, error)) {
func prepareAuthNKeyPublicKeyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) ([]byte, error)) {
return sq.Select(
AuthNKeyColumnPublicKey.identifier(),
).From(authNKeyTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) ([]byte, error) {
var publicKey []byte
err := row.Scan(
@@ -360,7 +364,7 @@ func prepareAuthNKeyPublicKeyQuery() (sq.SelectBuilder, func(row *sql.Row) ([]by
}
}
func prepareAuthNKeysDataQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeysData, error)) {
func prepareAuthNKeysDataQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeysData, error)) {
return sq.Select(
AuthNKeyColumnID.identifier(),
AuthNKeyColumnCreationDate.identifier(),
@@ -372,7 +376,8 @@ func prepareAuthNKeysDataQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*AuthN
AuthNKeyColumnIdentifier.identifier(),
AuthNKeyColumnPublicKey.identifier(),
countColumn.identifier(),
).From(authNKeyTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*AuthNKeysData, error) {
authNKeys := make([]*AuthNKeyData, 0)
var count uint64

View File

@@ -12,6 +12,80 @@ import (
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
prepareAuthNKeysStmt = `SELECT projections.authn_keys2.id,` +
` projections.authn_keys2.creation_date,` +
` projections.authn_keys2.change_date,` +
` projections.authn_keys2.resource_owner,` +
` projections.authn_keys2.sequence,` +
` projections.authn_keys2.expiration,` +
` projections.authn_keys2.type,` +
` COUNT(*) OVER ()` +
` FROM projections.authn_keys2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareAuthNKeysCols = []string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
"count",
}
prepareAuthNKeysDataStmt = `SELECT projections.authn_keys2.id,` +
` projections.authn_keys2.creation_date,` +
` projections.authn_keys2.change_date,` +
` projections.authn_keys2.resource_owner,` +
` projections.authn_keys2.sequence,` +
` projections.authn_keys2.expiration,` +
` projections.authn_keys2.type,` +
` projections.authn_keys2.identifier,` +
` projections.authn_keys2.public_key,` +
` COUNT(*) OVER ()` +
` FROM projections.authn_keys2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareAuthNKeysDataCols = []string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
"identifier",
"public_key",
"count",
}
prepareAuthNKeyStmt = `SELECT projections.authn_keys2.id,` +
` projections.authn_keys2.creation_date,` +
` projections.authn_keys2.change_date,` +
` projections.authn_keys2.resource_owner,` +
` projections.authn_keys2.sequence,` +
` projections.authn_keys2.expiration,` +
` projections.authn_keys2.type` +
` FROM projections.authn_keys2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareAuthNKeyCols = []string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
}
prepareAuthNKeyPublicKeyStmt = `SELECT projections.authn_keys2.public_key` +
` FROM projections.authn_keys2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareAuthNKeyPublicKeyCols = []string{
"public_key",
}
)
func Test_AuthNKeyPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -28,15 +102,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeysStmt),
nil,
nil,
),
@@ -48,25 +114,8 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
"count",
},
regexp.QuoteMeta(prepareAuthNKeysStmt),
prepareAuthNKeysCols,
[][]driver.Value{
{
"id",
@@ -102,25 +151,8 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
"count",
},
regexp.QuoteMeta(prepareAuthNKeysStmt),
prepareAuthNKeysCols,
[][]driver.Value{
{
"id-1",
@@ -174,15 +206,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeysStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -199,17 +223,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysDataQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` projections.authn_keys2.identifier,`+
` projections.authn_keys2.public_key,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeysDataStmt),
nil,
nil,
),
@@ -221,29 +235,8 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysDataQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` projections.authn_keys2.identifier,`+
` projections.authn_keys2.public_key,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
"identifier",
"public_key",
"count",
},
regexp.QuoteMeta(prepareAuthNKeysDataStmt),
prepareAuthNKeysDataCols,
[][]driver.Value{
{
"id",
@@ -283,29 +276,8 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysDataQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` projections.authn_keys2.identifier,`+
` projections.authn_keys2.public_key,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
"identifier",
"public_key",
"count",
},
regexp.QuoteMeta(prepareAuthNKeysDataStmt),
prepareAuthNKeysDataCols,
[][]driver.Value{
{
"id-1",
@@ -367,17 +339,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeysDataQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type,`+
` projections.authn_keys2.identifier,`+
` projections.authn_keys2.public_key,`+
` COUNT(*) OVER ()`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeysDataStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -394,14 +356,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeyQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeyStmt),
nil,
nil,
),
@@ -419,23 +374,8 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeyQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type`+
` FROM projections.authn_keys2`),
[]string{
"id",
"creation_date",
"change_date",
"resource_owner",
"sequence",
"expiration",
"type",
},
regexp.QuoteMeta(prepareAuthNKeyStmt),
prepareAuthNKeyCols,
[]driver.Value{
"id",
testNow,
@@ -462,14 +402,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeyQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.authn_keys2.id,`+
` projections.authn_keys2.creation_date,`+
` projections.authn_keys2.change_date,`+
` projections.authn_keys2.resource_owner,`+
` projections.authn_keys2.sequence,`+
` projections.authn_keys2.expiration,`+
` projections.authn_keys2.type`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeyStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -486,8 +419,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeyPublicKeyQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.authn_keys2.public_key`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeyPublicKeyStmt),
nil,
nil,
),
@@ -505,11 +437,8 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeyPublicKeyQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.authn_keys2.public_key`+
` FROM projections.authn_keys2`),
[]string{
"public_key",
},
regexp.QuoteMeta(prepareAuthNKeyPublicKeyStmt),
prepareAuthNKeyPublicKeyCols,
[]driver.Value{
[]byte("publicKey"),
},
@@ -522,8 +451,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
prepare: prepareAuthNKeyPublicKeyQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.authn_keys2.public_key`+
` FROM projections.authn_keys2`),
regexp.QuoteMeta(prepareAuthNKeyPublicKeyStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -538,7 +466,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -8,6 +8,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
@@ -69,7 +70,7 @@ func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage dom
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareCertificateQuery()
query, scan := prepareCertificateQuery(ctx, q.client)
if t.IsZero() {
t = time.Now()
}
@@ -102,7 +103,7 @@ func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage dom
return keys, nil
}
func prepareCertificateQuery() (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) {
func prepareCertificateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) {
return sq.Select(
KeyColID.identifier(),
KeyColCreationDate.identifier(),
@@ -117,7 +118,7 @@ func prepareCertificateQuery() (sq.SelectBuilder, func(*sql.Rows) (*Certificates
countColumn.identifier(),
).From(keyTable.identifier()).
LeftJoin(join(CertificateColID, KeyColID)).
LeftJoin(join(KeyPrivateColID, KeyColID)).
LeftJoin(join(KeyPrivateColID, KeyColID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*Certificates, error) {
certificates := make([]Certificate, 0)

View File

@@ -13,6 +13,37 @@ import (
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
prepareCertificateStmt = `SELECT projections.keys4.id,` +
` projections.keys4.creation_date,` +
` projections.keys4.change_date,` +
` projections.keys4.sequence,` +
` projections.keys4.resource_owner,` +
` projections.keys4.algorithm,` +
` projections.keys4.use,` +
` projections.keys4_certificate.expiry,` +
` projections.keys4_certificate.certificate,` +
` projections.keys4_private.key,` +
` COUNT(*) OVER ()` +
` FROM projections.keys4` +
` LEFT JOIN projections.keys4_certificate ON projections.keys4.id = projections.keys4_certificate.id AND projections.keys4.instance_id = projections.keys4_certificate.instance_id` +
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`
prepareCertificateCols = []string{
"id",
"creation_date",
"change_date",
"sequence",
"resource_owner",
"algorithm",
"use",
"expiry",
"certificate",
"key",
"count",
}
)
func Test_CertificatePrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -29,20 +60,7 @@ func Test_CertificatePrepares(t *testing.T) {
prepare: prepareCertificateQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_certificate.expiry,`+
` projections.keys4_certificate.certificate,`+
` projections.keys4_private.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_certificate ON projections.keys4.id = projections.keys4_certificate.id AND projections.keys4.instance_id = projections.keys4_certificate.instance_id`+
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id`),
regexp.QuoteMeta(prepareCertificateStmt),
nil,
nil,
),
@@ -60,33 +78,8 @@ func Test_CertificatePrepares(t *testing.T) {
prepare: prepareCertificateQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_certificate.expiry,`+
` projections.keys4_certificate.certificate,`+
` projections.keys4_private.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_certificate ON projections.keys4.id = projections.keys4_certificate.id AND projections.keys4.instance_id = projections.keys4_certificate.instance_id`+
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id`),
[]string{
"id",
"creation_date",
"change_date",
"sequence",
"resource_owner",
"algorithm",
"use",
"expiry",
"certificate",
"key",
"count",
},
regexp.QuoteMeta(prepareCertificateStmt),
prepareCertificateCols,
[][]driver.Value{
{
"key-id",
@@ -135,20 +128,7 @@ func Test_CertificatePrepares(t *testing.T) {
prepare: prepareCertificateQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_certificate.expiry,`+
` projections.keys4_certificate.certificate,`+
` projections.keys4_private.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_certificate ON projections.keys4.id = projections.keys4_certificate.id AND projections.keys4.instance_id = projections.keys4_certificate.instance_id`+
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id`),
regexp.QuoteMeta(prepareCertificateStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -163,7 +143,7 @@ func Test_CertificatePrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -93,7 +93,7 @@ func (q *Queries) UserChanges(ctx context.Context, userID string, lastSequence u
}
func (q *Queries) changes(ctx context.Context, query func(query *eventstore.SearchQuery), lastSequence uint64, limit uint64, sortAscending bool, auditLogRetention time.Duration) (*Changes, error) {
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(limit)
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(limit).AllowTimeTravel()
if !sortAscending {
builder.OrderDesc()
}

View File

@@ -10,6 +10,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
@@ -59,7 +60,7 @@ func (q *Queries) SearchCurrentSequences(ctx context.Context, queries *CurrentSe
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareCurrentSequencesQuery()
query, scan := prepareCurrentSequencesQuery(ctx, q.client)
stmt, args, err := queries.toQuery(query).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-MmFef", "Errors.Query.InvalidRequest")
@@ -76,7 +77,7 @@ func (q *Queries) latestSequence(ctx context.Context, projections ...table) (_ *
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLatestSequence()
query, scan := prepareLatestSequence(ctx, q.client)
or := make(sq.Or, len(projections))
for i, projection := range projections {
or[i] = sq.Eq{CurrentSequenceColProjectionName.identifier(): projection.name}
@@ -201,11 +202,12 @@ func reset(tx *sql.Tx, tables []string, projectionName string) error {
return nil
}
func prepareLatestSequence() (sq.SelectBuilder, func(*sql.Row) (*LatestSequence, error)) {
func prepareLatestSequence(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LatestSequence, error)) {
return sq.Select(
CurrentSequenceColCurrentSequence.identifier(),
CurrentSequenceColTimestamp.identifier()).
From(currentSequencesTable.identifier()).PlaceholderFormat(sq.Dollar),
From(currentSequencesTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*LatestSequence, error) {
seq := new(LatestSequence)
err := row.Scan(
@@ -219,13 +221,13 @@ func prepareLatestSequence() (sq.SelectBuilder, func(*sql.Row) (*LatestSequence,
}
}
func prepareCurrentSequencesQuery() (sq.SelectBuilder, func(*sql.Rows) (*CurrentSequences, error)) {
func prepareCurrentSequencesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*CurrentSequences, error)) {
return sq.Select(
"max("+CurrentSequenceColCurrentSequence.identifier()+") as "+CurrentSequenceColCurrentSequence.name,
"max("+CurrentSequenceColTimestamp.identifier()+") as "+CurrentSequenceColTimestamp.name,
CurrentSequenceColProjectionName.identifier(),
countColumn.identifier()).
From(currentSequencesTable.identifier()).
From(currentSequencesTable.identifier() + db.Timetravel(call.Took(ctx))).
GroupBy(CurrentSequenceColProjectionName.identifier()).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*CurrentSequences, error) {

View File

@@ -9,6 +9,23 @@ import (
"testing"
)
var (
currentSequenceStmt = `SELECT max(projections.current_sequences.current_sequence) as current_sequence,` +
` max(projections.current_sequences.timestamp) as timestamp,` +
` projections.current_sequences.projection_name,` +
` COUNT(*) OVER ()` +
` FROM projections.current_sequences` +
" AS OF SYSTEM TIME '-1 ms' " +
` GROUP BY projections.current_sequences.projection_name`
currentSequenceCols = []string{
"current_sequence",
"timestamp",
"projection_name",
"count",
}
)
func Test_CurrentSequencesPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -25,12 +42,7 @@ func Test_CurrentSequencesPrepares(t *testing.T) {
prepare: prepareCurrentSequencesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT max(projections.current_sequences.current_sequence) as current_sequence,`+
` max(projections.current_sequences.timestamp) as timestamp,`+
` projections.current_sequences.projection_name,`+
` COUNT(*) OVER ()`+
` FROM projections.current_sequences`+
` GROUP BY projections.current_sequences.projection_name`),
regexp.QuoteMeta(currentSequenceStmt),
nil,
nil,
),
@@ -42,18 +54,8 @@ func Test_CurrentSequencesPrepares(t *testing.T) {
prepare: prepareCurrentSequencesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT max(projections.current_sequences.current_sequence) as current_sequence,`+
` max(projections.current_sequences.timestamp) as timestamp,`+
` projections.current_sequences.projection_name,`+
` COUNT(*) OVER ()`+
` FROM projections.current_sequences`+
` GROUP BY projections.current_sequences.projection_name`),
[]string{
"current_sequence",
"timestamp",
"projection_name",
"count",
},
regexp.QuoteMeta(currentSequenceStmt),
currentSequenceCols,
[][]driver.Value{
{
uint64(20211108),
@@ -81,18 +83,8 @@ func Test_CurrentSequencesPrepares(t *testing.T) {
prepare: prepareCurrentSequencesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT max(projections.current_sequences.current_sequence) as current_sequence,`+
` max(projections.current_sequences.timestamp) as timestamp,`+
` projections.current_sequences.projection_name,`+
` COUNT(*) OVER ()`+
` FROM projections.current_sequences`+
` GROUP BY projections.current_sequences.projection_name`),
[]string{
"current_sequence",
"timestamp",
"projection_name",
"count",
},
regexp.QuoteMeta(currentSequenceStmt),
currentSequenceCols,
[][]driver.Value{
{
uint64(20211108),
@@ -130,12 +122,7 @@ func Test_CurrentSequencesPrepares(t *testing.T) {
prepare: prepareCurrentSequencesQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT max(projections.current_sequences.current_sequence) as current_sequence,`+
` max(projections.current_sequences.timestamp) as timestamp,`+
` projections.current_sequences.projection_name,`+
` COUNT(*) OVER ()`+
` FROM projections.current_sequences`+
` GROUP BY projections.current_sequences.projection_name`),
regexp.QuoteMeta(currentSequenceStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -150,7 +137,7 @@ func Test_CurrentSequencesPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -13,6 +13,7 @@ import (
"sigs.k8s.io/yaml"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
@@ -88,7 +89,7 @@ func (q *Queries) CustomTextList(ctx context.Context, aggregateID, template, lan
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareCustomTextsQuery()
stmt, scan := prepareCustomTextsQuery(ctx, q.client)
eq := sq.Eq{
CustomTextColAggregateID.identifier(): aggregateID,
CustomTextColTemplate.identifier(): template,
@@ -119,7 +120,7 @@ func (q *Queries) CustomTextListByTemplate(ctx context.Context, aggregateID, tem
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareCustomTextsQuery()
stmt, scan := prepareCustomTextsQuery(ctx, q.client)
eq := sq.Eq{
CustomTextColAggregateID.identifier(): aggregateID,
CustomTextColTemplate.identifier(): template,
@@ -228,7 +229,7 @@ func (q *Queries) readLoginTranslationFile(ctx context.Context, lang string) ([]
return contents, nil
}
func prepareCustomTextsQuery() (sq.SelectBuilder, func(*sql.Rows) (*CustomTexts, error)) {
func prepareCustomTextsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*CustomTexts, error)) {
return sq.Select(
CustomTextColAggregateID.identifier(),
CustomTextColSequence.identifier(),
@@ -239,7 +240,8 @@ func prepareCustomTextsQuery() (sq.SelectBuilder, func(*sql.Rows) (*CustomTexts,
CustomTextColKey.identifier(),
CustomTextColText.identifier(),
countColumn.identifier()).
From(customTextTable.identifier()).PlaceholderFormat(sq.Dollar),
From(customTextTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*CustomTexts, error) {
customTexts := make([]*CustomText, 0)
var count uint64

View File

@@ -13,6 +13,31 @@ import (
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
prepareCustomTextsStmt = `SELECT projections.custom_texts2.aggregate_id,` +
` projections.custom_texts2.sequence,` +
` projections.custom_texts2.creation_date,` +
` projections.custom_texts2.change_date,` +
` projections.custom_texts2.language,` +
` projections.custom_texts2.template,` +
` projections.custom_texts2.key,` +
` projections.custom_texts2.text,` +
` COUNT(*) OVER ()` +
` FROM projections.custom_texts2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareCustomTextsCols = []string{
"aggregate_id",
"sequence",
"creation_date",
"change_date",
"language",
"template",
"key",
"text",
"count",
}
)
func Test_CustomTextPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -29,16 +54,7 @@ func Test_CustomTextPrepares(t *testing.T) {
prepare: prepareCustomTextsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.custom_texts2.aggregate_id,`+
` projections.custom_texts2.sequence,`+
` projections.custom_texts2.creation_date,`+
` projections.custom_texts2.change_date,`+
` projections.custom_texts2.language,`+
` projections.custom_texts2.template,`+
` projections.custom_texts2.key,`+
` projections.custom_texts2.text,`+
` COUNT(*) OVER ()`+
` FROM projections.custom_texts2`),
regexp.QuoteMeta(prepareCustomTextsStmt),
nil,
nil,
),
@@ -56,27 +72,8 @@ func Test_CustomTextPrepares(t *testing.T) {
prepare: prepareCustomTextsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.custom_texts2.aggregate_id,`+
` projections.custom_texts2.sequence,`+
` projections.custom_texts2.creation_date,`+
` projections.custom_texts2.change_date,`+
` projections.custom_texts2.language,`+
` projections.custom_texts2.template,`+
` projections.custom_texts2.key,`+
` projections.custom_texts2.text,`+
` COUNT(*) OVER ()`+
` FROM projections.custom_texts2`),
[]string{
"aggregate_id",
"sequence",
"creation_date",
"change_date",
"language",
"template",
"key",
"text",
"count",
},
regexp.QuoteMeta(prepareCustomTextsStmt),
prepareCustomTextsCols,
[][]driver.Value{
{
"agg-id",
@@ -114,27 +111,8 @@ func Test_CustomTextPrepares(t *testing.T) {
prepare: prepareCustomTextsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.custom_texts2.aggregate_id,`+
` projections.custom_texts2.sequence,`+
` projections.custom_texts2.creation_date,`+
` projections.custom_texts2.change_date,`+
` projections.custom_texts2.language,`+
` projections.custom_texts2.template,`+
` projections.custom_texts2.key,`+
` projections.custom_texts2.text,`+
` COUNT(*) OVER ()`+
` FROM projections.custom_texts2`),
[]string{
"aggregate_id",
"sequence",
"creation_date",
"change_date",
"language",
"template",
"key",
"text",
"count",
},
regexp.QuoteMeta(prepareCustomTextsStmt),
prepareCustomTextsCols,
[][]driver.Value{
{
"agg-id",
@@ -192,16 +170,7 @@ func Test_CustomTextPrepares(t *testing.T) {
prepare: prepareCustomTextsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.custom_texts2.aggregate_id,`+
` projections.custom_texts2.sequence,`+
` projections.custom_texts2.creation_date,`+
` projections.custom_texts2.change_date,`+
` projections.custom_texts2.language,`+
` projections.custom_texts2.template,`+
` projections.custom_texts2.key,`+
` projections.custom_texts2.text,`+
` COUNT(*) OVER ()`+
` FROM projections.custom_texts2`),
regexp.QuoteMeta(prepareCustomTextsStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -216,7 +185,7 @@ func Test_CustomTextPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -112,7 +113,7 @@ func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool,
}
}
stmt, scan := prepareDomainPolicyQuery()
stmt, scan := prepareDomainPolicyQuery(ctx, q.client)
query, args, err := stmt.Where(eq).OrderBy(DomainPolicyColIsDefault.identifier()).
Limit(1).ToSql()
if err != nil {
@@ -127,7 +128,7 @@ func (q *Queries) DefaultDomainPolicy(ctx context.Context) (_ *DomainPolicy, err
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareDomainPolicyQuery()
stmt, scan := prepareDomainPolicyQuery(ctx, q.client)
query, args, err := stmt.Where(sq.Eq{
DomainPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
DomainPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -142,7 +143,7 @@ func (q *Queries) DefaultDomainPolicy(ctx context.Context) (_ *DomainPolicy, err
return scan(row)
}
func prepareDomainPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy, error)) {
func prepareDomainPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy, error)) {
return sq.Select(
DomainPolicyColID.identifier(),
DomainPolicyColSequence.identifier(),
@@ -155,7 +156,8 @@ func prepareDomainPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy
DomainPolicyColIsDefault.identifier(),
DomainPolicyColState.identifier(),
).
From(domainPolicyTable.identifier()).PlaceholderFormat(sq.Dollar),
From(domainPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*DomainPolicy, error) {
policy := new(DomainPolicy)
err := row.Scan(

View File

@@ -12,6 +12,33 @@ import (
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
prepareDomainPolicyStmt = `SELECT projections.domain_policies2.id,` +
` projections.domain_policies2.sequence,` +
` projections.domain_policies2.creation_date,` +
` projections.domain_policies2.change_date,` +
` projections.domain_policies2.resource_owner,` +
` projections.domain_policies2.user_login_must_be_domain,` +
` projections.domain_policies2.validate_org_domains,` +
` projections.domain_policies2.smtp_sender_address_matches_instance_domain,` +
` projections.domain_policies2.is_default,` +
` projections.domain_policies2.state` +
` FROM projections.domain_policies2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareDomainPolicyCols = []string{
"id",
"sequence",
"creation_date",
"change_date",
"resource_owner",
"user_login_must_be_domain",
"validate_org_domains",
"smtp_sender_address_matches_instance_domain",
"is_default",
"state",
}
)
func Test_DomainPolicyPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -28,17 +55,7 @@ func Test_DomainPolicyPrepares(t *testing.T) {
prepare: prepareDomainPolicyQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.domain_policies2.id,`+
` projections.domain_policies2.sequence,`+
` projections.domain_policies2.creation_date,`+
` projections.domain_policies2.change_date,`+
` projections.domain_policies2.resource_owner,`+
` projections.domain_policies2.user_login_must_be_domain,`+
` projections.domain_policies2.validate_org_domains,`+
` projections.domain_policies2.smtp_sender_address_matches_instance_domain,`+
` projections.domain_policies2.is_default,`+
` projections.domain_policies2.state`+
` FROM projections.domain_policies2`),
regexp.QuoteMeta(prepareDomainPolicyStmt),
nil,
nil,
),
@@ -56,29 +73,8 @@ func Test_DomainPolicyPrepares(t *testing.T) {
prepare: prepareDomainPolicyQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.domain_policies2.id,`+
` projections.domain_policies2.sequence,`+
` projections.domain_policies2.creation_date,`+
` projections.domain_policies2.change_date,`+
` projections.domain_policies2.resource_owner,`+
` projections.domain_policies2.user_login_must_be_domain,`+
` projections.domain_policies2.validate_org_domains,`+
` projections.domain_policies2.smtp_sender_address_matches_instance_domain,`+
` projections.domain_policies2.is_default,`+
` projections.domain_policies2.state`+
` FROM projections.domain_policies2`),
[]string{
"id",
"sequence",
"creation_date",
"change_date",
"resource_owner",
"user_login_must_be_domain",
"validate_org_domains",
"smtp_sender_address_matches_instance_domain",
"is_default",
"state",
},
regexp.QuoteMeta(prepareDomainPolicyStmt),
prepareDomainPolicyCols,
[]driver.Value{
"pol-id",
uint64(20211109),
@@ -111,17 +107,7 @@ func Test_DomainPolicyPrepares(t *testing.T) {
prepare: prepareDomainPolicyQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.domain_policies2.id,`+
` projections.domain_policies2.sequence,`+
` projections.domain_policies2.creation_date,`+
` projections.domain_policies2.change_date,`+
` projections.domain_policies2.resource_owner,`+
` projections.domain_policies2.user_login_must_be_domain,`+
` projections.domain_policies2.validate_org_domains,`+
` projections.domain_policies2.smtp_sender_address_matches_instance_domain,`+
` projections.domain_policies2.is_default,`+
` projections.domain_policies2.state`+
` FROM projections.domain_policies2`),
regexp.QuoteMeta(prepareDomainPolicyStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -136,7 +122,7 @@ func Test_DomainPolicyPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -26,7 +26,7 @@ type EventEditor struct {
func (q *Queries) SearchEvents(ctx context.Context, query *eventstore.SearchQueryBuilder) (_ []*Event, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
events, err := q.eventstore.Filter(ctx, query)
events, err := q.eventstore.Filter(ctx, query.AllowTimeTravel())
if err != nil {
return nil, err
}

View File

@@ -7,6 +7,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
)
@@ -70,7 +71,7 @@ type FailedEventSearchQueries struct {
}
func (q *Queries) SearchFailedEvents(ctx context.Context, queries *FailedEventSearchQueries) (failedEvents *FailedEvents, err error) {
query, scan := prepareFailedEventsQuery()
query, scan := prepareFailedEventsQuery(ctx, q.client)
stmt, args, err := queries.toQuery(query).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-n8rjJ", "Errors.Query.InvalidRequest")
@@ -123,7 +124,7 @@ func (q *FailedEventSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuil
return query
}
func prepareFailedEventsQuery() (sq.SelectBuilder, func(*sql.Rows) (*FailedEvents, error)) {
func prepareFailedEventsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*FailedEvents, error)) {
return sq.Select(
FailedEventsColumnProjectionName.identifier(),
FailedEventsColumnFailedSequence.identifier(),
@@ -131,7 +132,8 @@ func prepareFailedEventsQuery() (sq.SelectBuilder, func(*sql.Rows) (*FailedEvent
FailedEventsColumnLastFailed.identifier(),
FailedEventsColumnError.identifier(),
countColumn.identifier()).
From(failedEventsTable.identifier()).PlaceholderFormat(sq.Dollar),
From(failedEventsTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*FailedEvents, error) {
failedEvents := make([]*FailedEvent, 0)
var count uint64

View File

@@ -9,6 +9,26 @@ import (
"testing"
)
var (
prepareFailedEventsStmt = `SELECT projections.failed_events.projection_name,` +
` projections.failed_events.failed_sequence,` +
` projections.failed_events.failure_count,` +
` projections.failed_events.last_failed,` +
` projections.failed_events.error,` +
` COUNT(*) OVER ()` +
` FROM projections.failed_events` +
` AS OF SYSTEM TIME '-1 ms'`
prepareFailedEventsCols = []string{
"projection_name",
"failed_sequence",
"failure_count",
"last_failed",
"error",
"count",
}
)
func Test_FailedEventsPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -25,13 +45,7 @@ func Test_FailedEventsPrepares(t *testing.T) {
prepare: prepareFailedEventsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.failed_events.projection_name,`+
` projections.failed_events.failed_sequence,`+
` projections.failed_events.failure_count,`+
` projections.failed_events.last_failed,`+
` projections.failed_events.error,`+
` COUNT(*) OVER ()`+
` FROM projections.failed_events`),
regexp.QuoteMeta(prepareFailedEventsStmt),
nil,
nil,
),
@@ -43,21 +57,8 @@ func Test_FailedEventsPrepares(t *testing.T) {
prepare: prepareFailedEventsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.failed_events.projection_name,`+
` projections.failed_events.failed_sequence,`+
` projections.failed_events.failure_count,`+
` projections.failed_events.last_failed,`+
` projections.failed_events.error,`+
` COUNT(*) OVER ()`+
` FROM projections.failed_events`),
[]string{
"projection_name",
"failed_sequence",
"failure_count",
"last_failed",
"error",
"count",
},
regexp.QuoteMeta(prepareFailedEventsStmt),
prepareFailedEventsCols,
[][]driver.Value{
{
"projection-name",
@@ -89,21 +90,8 @@ func Test_FailedEventsPrepares(t *testing.T) {
prepare: prepareFailedEventsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.failed_events.projection_name,`+
` projections.failed_events.failed_sequence,`+
` projections.failed_events.failure_count,`+
` projections.failed_events.last_failed,`+
` projections.failed_events.error,`+
` COUNT(*) OVER ()`+
` FROM projections.failed_events`),
[]string{
"projection_name",
"failed_sequence",
"failure_count",
"last_failed",
"error",
"count",
},
regexp.QuoteMeta(prepareFailedEventsStmt),
prepareFailedEventsCols,
[][]driver.Value{
{
"projection-name",
@@ -148,13 +136,7 @@ func Test_FailedEventsPrepares(t *testing.T) {
prepare: prepareFailedEventsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.failed_events.projection_name,`+
` projections.failed_events.failed_sequence,`+
` projections.failed_events.failure_count,`+
` projections.failed_events.last_failed,`+
` projections.failed_events.error,`+
` COUNT(*) OVER ()`+
` FROM projections.failed_events`),
regexp.QuoteMeta(prepareFailedEventsStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -169,7 +151,7 @@ func Test_FailedEventsPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -7,6 +7,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
@@ -78,7 +79,7 @@ func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, with
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareInstanceMembersQuery()
query, scan := prepareInstanceMembersQuery(ctx, q.client)
eq := sq.Eq{InstanceMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
addIamMemberWithoutOwnerRemoved(eq)
@@ -106,7 +107,7 @@ func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, with
return members, err
}
func prepareInstanceMembersQuery() (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) {
func prepareInstanceMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) {
return sq.Select(
InstanceMemberCreationDate.identifier(),
InstanceMemberChangeDate.identifier(),
@@ -125,7 +126,7 @@ func prepareInstanceMembersQuery() (sq.SelectBuilder, func(*sql.Rows) (*Members,
).From(instanceMemberTable.identifier()).
LeftJoin(join(HumanUserIDCol, InstanceMemberUserID)).
LeftJoin(join(MachineUserIDCol, InstanceMemberUserID)).
LeftJoin(join(LoginNameUserIDCol, InstanceMemberUserID)).
LeftJoin(join(LoginNameUserIDCol, InstanceMemberUserID) + db.Timetravel(call.Took(ctx))).
Where(
sq.Eq{LoginNameIsPrimaryCol.identifier(): true},
).PlaceholderFormat(sq.Dollar),

View File

@@ -34,6 +34,7 @@ var (
"ON members.user_id = projections.users8_machines.user_id AND members.instance_id = projections.users8_machines.instance_id " +
"LEFT JOIN projections.login_names2 " +
"ON members.user_id = projections.login_names2.user_id AND members.instance_id = projections.login_names2.instance_id " +
"AS OF SYSTEM TIME '-1 ms' " +
"WHERE projections.login_names2.is_primary = $1")
instanceMembersColumns = []string{
"creation_date",
@@ -271,7 +272,7 @@ func Test_IAMMemberPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
@@ -209,7 +210,7 @@ func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk
sq.Eq{IDPResourceOwnerCol.identifier(): authz.GetInstance(ctx).InstanceID()},
},
}
stmt, scan := prepareIDPByIDQuery()
stmt, scan := prepareIDPByIDQuery(ctx, q.client)
query, args, err := stmt.Where(where).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-0gocI", "Errors.Query.SQLStatement")
@@ -224,7 +225,7 @@ func (q *Queries) IDPs(ctx context.Context, queries *IDPSearchQueries, withOwner
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareIDPsQuery()
query, scan := prepareIDPsQuery(ctx, q.client)
eq := sq.Eq{
IDPInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
@@ -285,7 +286,7 @@ func (q *IDPSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query
}
func prepareIDPByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDP, error)) {
func prepareIDPByIDQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*IDP, error)) {
return sq.Select(
IDPIDCol.identifier(),
IDPResourceOwnerCol.identifier(),
@@ -313,7 +314,7 @@ func prepareIDPByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDP, error)) {
JWTIDPColEndpoint.identifier(),
).From(idpTable.identifier()).
LeftJoin(join(OIDCIDPColIDPID, IDPIDCol)).
LeftJoin(join(JWTIDPColIDPID, IDPIDCol)).
LeftJoin(join(JWTIDPColIDPID, IDPIDCol) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*IDP, error) {
idp := new(IDP)
@@ -393,7 +394,7 @@ func prepareIDPByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDP, error)) {
}
}
func prepareIDPsQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPs, error)) {
func prepareIDPsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPs, error)) {
return sq.Select(
IDPIDCol.identifier(),
IDPResourceOwnerCol.identifier(),
@@ -422,7 +423,7 @@ func prepareIDPsQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPs, error)) {
countColumn.identifier(),
).From(idpTable.identifier()).
LeftJoin(join(OIDCIDPColIDPID, IDPIDCol)).
LeftJoin(join(JWTIDPColIDPID, IDPIDCol)).
LeftJoin(join(JWTIDPColIDPID, IDPIDCol) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*IDPs, error) {
idps := make([]*IDP, 0)

View File

@@ -7,6 +7,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -84,7 +85,7 @@ func (q *Queries) IDPLoginPolicyLinks(ctx context.Context, resourceOwner string,
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareIDPLoginPolicyLinksQuery()
query, scan := prepareIDPLoginPolicyLinksQuery(ctx, q.client)
eq := sq.Eq{
IDPLoginPolicyLinkResourceOwnerCol.identifier(): resourceOwner,
IDPLoginPolicyLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -109,14 +110,15 @@ func (q *Queries) IDPLoginPolicyLinks(ctx context.Context, resourceOwner string,
return idps, err
}
func prepareIDPLoginPolicyLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) {
func prepareIDPLoginPolicyLinksQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) {
return sq.Select(
IDPLoginPolicyLinkIDPIDCol.identifier(),
IDPNameCol.identifier(),
IDPTypeCol.identifier(),
countColumn.identifier()).
From(idpLoginPolicyLinkTable.identifier()).
LeftJoin(join(IDPIDCol, IDPLoginPolicyLinkIDPIDCol)).PlaceholderFormat(sq.Dollar),
LeftJoin(join(IDPIDCol, IDPLoginPolicyLinkIDPIDCol) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*IDPLoginPolicyLinks, error) {
links := make([]*IDPLoginPolicyLink, 0)
var count uint64

View File

@@ -17,7 +17,8 @@ var (
` projections.idps3.type,` +
` COUNT(*) OVER ()` +
` FROM projections.idp_login_policy_links4` +
` LEFT JOIN projections.idps3 ON projections.idp_login_policy_links4.idp_id = projections.idps3.id`)
` LEFT JOIN projections.idps3 ON projections.idp_login_policy_links4.idp_id = projections.idps3.id AND projections.idp_login_policy_links4.instance_id = projections.idps3.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
loginPolicyIDPLinksCols = []string{
"idp_id",
"name",
@@ -115,7 +116,7 @@ func Test_IDPLoginPolicyLinkPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
@@ -409,7 +410,7 @@ func (q *Queries) IDPTemplateByIDAndResourceOwner(ctx context.Context, shouldTri
sq.Eq{IDPTemplateResourceOwnerCol.identifier(): authz.GetInstance(ctx).InstanceID()},
},
}
stmt, scan := prepareIDPTemplateByIDQuery()
stmt, scan := prepareIDPTemplateByIDQuery(ctx, q.client)
query, args, err := stmt.Where(where).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SFAew", "Errors.Query.SQLStatement")
@@ -424,7 +425,7 @@ func (q *Queries) IDPTemplates(ctx context.Context, queries *IDPTemplateSearchQu
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareIDPTemplatesQuery()
query, scan := prepareIDPTemplatesQuery(ctx, q.client)
eq := sq.Eq{
IDPTemplateInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
@@ -485,7 +486,7 @@ func (q *IDPTemplateSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuil
return query
}
func prepareIDPTemplateByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDPTemplate, error)) {
func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*IDPTemplate, error)) {
return sq.Select(
IDPTemplateIDCol.identifier(),
IDPTemplateResourceOwnerCol.identifier(),
@@ -553,7 +554,7 @@ func prepareIDPTemplateByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDPTempla
LeftJoin(join(OIDCIDCol, IDPTemplateIDCol)).
LeftJoin(join(JWTIDCol, IDPTemplateIDCol)).
LeftJoin(join(GoogleIDCol, IDPTemplateIDCol)).
LeftJoin(join(LDAPIDCol, IDPTemplateIDCol)).
LeftJoin(join(LDAPIDCol, IDPTemplateIDCol) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*IDPTemplate, error) {
idpTemplate := new(IDPTemplate)
@@ -750,7 +751,7 @@ func prepareIDPTemplateByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDPTempla
}
}
func prepareIDPTemplatesQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPTemplates, error)) {
func prepareIDPTemplatesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPTemplates, error)) {
return sq.Select(
IDPTemplateIDCol.identifier(),
IDPTemplateResourceOwnerCol.identifier(),
@@ -819,7 +820,7 @@ func prepareIDPTemplatesQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPTemplate
LeftJoin(join(OIDCIDCol, IDPTemplateIDCol)).
LeftJoin(join(JWTIDCol, IDPTemplateIDCol)).
LeftJoin(join(GoogleIDCol, IDPTemplateIDCol)).
LeftJoin(join(LDAPIDCol, IDPTemplateIDCol)).
LeftJoin(join(LDAPIDCol, IDPTemplateIDCol) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*IDPTemplates, error) {
templates := make([]*IDPTemplate, 0)

View File

@@ -81,7 +81,8 @@ var (
` LEFT JOIN projections.idp_templates2_oidc ON projections.idp_templates2.id = projections.idp_templates2_oidc.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oidc.instance_id` +
` LEFT JOIN projections.idp_templates2_jwt ON projections.idp_templates2.id = projections.idp_templates2_jwt.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_jwt.instance_id` +
` LEFT JOIN projections.idp_templates2_google ON projections.idp_templates2.id = projections.idp_templates2_google.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_google.instance_id` +
` LEFT JOIN projections.idp_templates2_ldap ON projections.idp_templates2.id = projections.idp_templates2_ldap.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_ldap.instance_id`
` LEFT JOIN projections.idp_templates2_ldap ON projections.idp_templates2.id = projections.idp_templates2_ldap.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_ldap.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`
idpTemplateCols = []string{
"id",
"resource_owner",
@@ -212,7 +213,8 @@ var (
` LEFT JOIN projections.idp_templates2_oidc ON projections.idp_templates2.id = projections.idp_templates2_oidc.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oidc.instance_id` +
` LEFT JOIN projections.idp_templates2_jwt ON projections.idp_templates2.id = projections.idp_templates2_jwt.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_jwt.instance_id` +
` LEFT JOIN projections.idp_templates2_google ON projections.idp_templates2.id = projections.idp_templates2_google.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_google.instance_id` +
` LEFT JOIN projections.idp_templates2_ldap ON projections.idp_templates2.id = projections.idp_templates2_ldap.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_ldap.instance_id`
` LEFT JOIN projections.idp_templates2_ldap ON projections.idp_templates2.id = projections.idp_templates2_ldap.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_ldap.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`
idpTemplatesCols = []string{
"id",
"resource_owner",
@@ -1628,7 +1630,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -733,7 +733,7 @@ func Test_IDPPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -7,6 +7,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -92,7 +93,7 @@ func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQ
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareIDPUserLinksQuery()
query, scan := prepareIDPUserLinksQuery(ctx, q.client)
eq := sq.Eq{IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[IDPUserLinkOwnerRemovedCol.identifier()] = false
@@ -126,7 +127,7 @@ func NewIDPUserLinksResourceOwnerSearchQuery(value string) (SearchQuery, error)
return NewTextQuery(IDPUserLinkResourceOwnerCol, value, TextEquals)
}
func prepareIDPUserLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPUserLinks, error)) {
func prepareIDPUserLinksQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPUserLinks, error)) {
return sq.Select(
IDPUserLinkIDPIDCol.identifier(),
IDPUserLinkUserIDCol.identifier(),
@@ -137,7 +138,8 @@ func prepareIDPUserLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPUserLink
IDPUserLinkResourceOwnerCol.identifier(),
countColumn.identifier()).
From(idpUserLinkTable.identifier()).
LeftJoin(join(IDPIDCol, IDPUserLinkIDPIDCol)).PlaceholderFormat(sq.Dollar),
LeftJoin(join(IDPIDCol, IDPUserLinkIDPIDCol) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*IDPUserLinks, error) {
idps := make([]*IDPUserLink, 0)
var count uint64

View File

@@ -21,7 +21,8 @@ var (
` projections.idp_user_links3.resource_owner,` +
` COUNT(*) OVER ()` +
` FROM projections.idp_user_links3` +
` LEFT JOIN projections.idps3 ON projections.idp_user_links3.idp_id = projections.idps3.id`)
` LEFT JOIN projections.idps3 ON projections.idp_user_links3.idp_id = projections.idps3.id AND projections.idp_user_links3.instance_id = projections.idps3.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
idpUserLinksCols = []string{
"idp_id",
"user_id",
@@ -139,7 +140,7 @@ func Test_IDPUserLinkPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -11,6 +11,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -159,7 +160,7 @@ func (q *Queries) SearchInstances(ctx context.Context, queries *InstanceSearchQu
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
filter, query, scan := prepareInstancesQuery()
filter, query, scan := prepareInstancesQuery(ctx, q.client)
stmt, args, err := query(queries.toQuery(filter)).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-M9fow", "Errors.Query.SQLStatement")
@@ -184,7 +185,7 @@ func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (_ *Inst
projection.InstanceProjection.Trigger(ctx)
}
stmt, scan := prepareInstanceDomainQuery(authz.GetInstance(ctx).RequestedDomain())
stmt, scan := prepareInstanceDomainQuery(ctx, q.client, authz.GetInstance(ctx).RequestedDomain())
query, args, err := stmt.Where(sq.Eq{
InstanceColumnID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).ToSql()
@@ -203,7 +204,7 @@ func (q *Queries) InstanceByHost(ctx context.Context, host string) (_ authz.Inst
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareAuthzInstanceQuery(host)
stmt, scan := prepareAuthzInstanceQuery(ctx, q.client, host)
host = strings.Split(host, ":")[0] //remove possible port
query, args, err := stmt.Where(sq.Eq{
InstanceDomainDomainCol.identifier(): host,
@@ -231,7 +232,7 @@ func (q *Queries) GetDefaultLanguage(ctx context.Context) language.Tag {
return instance.DefaultLanguage()
}
func prepareInstanceQuery(host string) (sq.SelectBuilder, func(*sql.Row) (*Instance, error)) {
func prepareInstanceQuery(ctx context.Context, db prepareDatabase, host string) (sq.SelectBuilder, func(*sql.Row) (*Instance, error)) {
return sq.Select(
InstanceColumnID.identifier(),
InstanceColumnCreationDate.identifier(),
@@ -243,7 +244,8 @@ func prepareInstanceQuery(host string) (sq.SelectBuilder, func(*sql.Row) (*Insta
InstanceColumnConsoleAppID.identifier(),
InstanceColumnDefaultLanguage.identifier(),
).
From(instanceTable.identifier()).PlaceholderFormat(sq.Dollar),
From(instanceTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*Instance, error) {
instance := &Instance{host: host}
lang := ""
@@ -269,7 +271,7 @@ func prepareInstanceQuery(host string) (sq.SelectBuilder, func(*sql.Row) (*Insta
}
}
func prepareInstancesQuery() (sq.SelectBuilder, func(sq.SelectBuilder) sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
func prepareInstancesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(sq.SelectBuilder) sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
instanceFilterTable := instanceTable.setAlias(InstancesFilterTableAlias)
instanceFilterIDColumn := InstanceColumnID.setTable(instanceFilterTable)
instanceFilterCountColumn := InstancesFilterTableAlias + ".count"
@@ -298,7 +300,7 @@ func prepareInstancesQuery() (sq.SelectBuilder, func(sq.SelectBuilder) sq.Select
InstanceDomainSequenceCol.identifier(),
).FromSelect(builder, InstancesFilterTableAlias).
LeftJoin(join(InstanceColumnID, instanceFilterIDColumn)).
LeftJoin(join(InstanceDomainInstanceIDCol, instanceFilterIDColumn)).
LeftJoin(join(InstanceDomainInstanceIDCol, instanceFilterIDColumn) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar)
},
func(rows *sql.Rows) (*Instances, error) {
@@ -373,7 +375,7 @@ func prepareInstancesQuery() (sq.SelectBuilder, func(sq.SelectBuilder) sq.Select
}
}
func prepareInstanceDomainQuery(host string) (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) {
func prepareInstanceDomainQuery(ctx context.Context, db prepareDatabase, host string) (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) {
return sq.Select(
InstanceColumnID.identifier(),
InstanceColumnCreationDate.identifier(),
@@ -393,7 +395,7 @@ func prepareInstanceDomainQuery(host string) (sq.SelectBuilder, func(*sql.Rows)
InstanceDomainSequenceCol.identifier(),
).
From(instanceTable.identifier()).
LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID)).
LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*Instance, error) {
instance := &Instance{
@@ -455,7 +457,7 @@ func prepareInstanceDomainQuery(host string) (sq.SelectBuilder, func(*sql.Rows)
}
}
func prepareAuthzInstanceQuery(host string) (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) {
func prepareAuthzInstanceQuery(ctx context.Context, db prepareDatabase, host string) (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) {
return sq.Select(
InstanceColumnID.identifier(),
InstanceColumnCreationDate.identifier(),
@@ -478,7 +480,7 @@ func prepareAuthzInstanceQuery(host string) (sq.SelectBuilder, func(*sql.Rows) (
).
From(instanceTable.identifier()).
LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID)).
LeftJoin(join(SecurityPolicyColumnInstanceID, InstanceColumnID)).
LeftJoin(join(SecurityPolicyColumnInstanceID, InstanceColumnID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*Instance, error) {
instance := &Instance{

View File

@@ -8,6 +8,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
@@ -61,7 +62,7 @@ func (q *Queries) SearchInstanceDomains(ctx context.Context, queries *InstanceDo
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareInstanceDomainsQuery()
query, scan := prepareInstanceDomainsQuery(ctx, q.client)
stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
InstanceDomainInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -77,7 +78,7 @@ func (q *Queries) SearchInstanceDomainsGlobal(ctx context.Context, queries *Inst
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareInstanceDomainsQuery()
query, scan := prepareInstanceDomainsQuery(ctx, q.client)
stmt, args, err := queries.toQuery(query).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-IHhLR", "Errors.Query.SQLStatement")
@@ -99,7 +100,7 @@ func (q *Queries) queryInstanceDomains(ctx context.Context, stmt string, scan fu
return domains, err
}
func prepareInstanceDomainsQuery() (sq.SelectBuilder, func(*sql.Rows) (*InstanceDomains, error)) {
func prepareInstanceDomainsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*InstanceDomains, error)) {
return sq.Select(
InstanceDomainCreationDateCol.identifier(),
InstanceDomainChangeDateCol.identifier(),
@@ -109,7 +110,8 @@ func prepareInstanceDomainsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Instance
InstanceDomainIsGeneratedCol.identifier(),
InstanceDomainIsPrimaryCol.identifier(),
countColumn.identifier(),
).From(instanceDomainsTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(instanceDomainsTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*InstanceDomains, error) {
domains := make([]*InstanceDomain, 0)
var count uint64

View File

@@ -9,6 +9,29 @@ import (
"testing"
)
var (
prepareInstanceDomainsStmt = `SELECT projections.instance_domains.creation_date,` +
` projections.instance_domains.change_date,` +
` projections.instance_domains.sequence,` +
` projections.instance_domains.domain,` +
` projections.instance_domains.instance_id,` +
` projections.instance_domains.is_generated,` +
` projections.instance_domains.is_primary,` +
` COUNT(*) OVER ()` +
` FROM projections.instance_domains` +
` AS OF SYSTEM TIME '-1 ms'`
prepareInstanceDomainsCols = []string{
"creation_date",
"change_date",
"sequence",
"domain",
"instance_id",
"is_generated",
"is_primary",
"count",
}
)
func Test_InstanceDomainPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -25,15 +48,7 @@ func Test_InstanceDomainPrepares(t *testing.T) {
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` projections.instance_domains.is_primary,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
regexp.QuoteMeta(prepareInstanceDomainsStmt),
nil,
nil,
),
@@ -45,25 +60,8 @@ func Test_InstanceDomainPrepares(t *testing.T) {
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` projections.instance_domains.is_primary,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
[]string{
"creation_date",
"change_date",
"sequence",
"domain",
"instance_id",
"is_generated",
"is_primary",
"count",
},
regexp.QuoteMeta(prepareInstanceDomainsStmt),
prepareInstanceDomainsCols,
[][]driver.Value{
{
testNow,
@@ -99,25 +97,8 @@ func Test_InstanceDomainPrepares(t *testing.T) {
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` projections.instance_domains.is_primary,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
[]string{
"creation_date",
"change_date",
"sequence",
"domain",
"instance_id",
"is_generated",
"is_primary",
"count",
},
regexp.QuoteMeta(prepareInstanceDomainsStmt),
prepareInstanceDomainsCols,
[][]driver.Value{
{
testNow,
@@ -171,15 +152,7 @@ func Test_InstanceDomainPrepares(t *testing.T) {
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` projections.instance_domains.is_primary,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
regexp.QuoteMeta(prepareInstanceDomainsStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -194,7 +167,7 @@ func Test_InstanceDomainPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -1,10 +1,12 @@
package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"regexp"
"testing"
@@ -24,7 +26,8 @@ var (
` projections.instances.console_client_id,` +
` projections.instances.console_app_id,` +
` projections.instances.default_language` +
` FROM projections.instances`
` FROM projections.instances` +
` AS OF SYSTEM TIME '-1 ms'`
instanceCols = []string{
"id",
"creation_date",
@@ -54,7 +57,8 @@ var (
` projections.instance_domains.sequence` +
` FROM (SELECT projections.instances.id, COUNT(*) OVER () FROM projections.instances) AS f` +
` LEFT JOIN projections.instances ON f.id = projections.instances.id` +
` LEFT JOIN projections.instance_domains ON f.id = projections.instance_domains.instance_id`
` LEFT JOIN projections.instance_domains ON f.id = projections.instance_domains.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`
instancesCols = []string{
"count",
"id",
@@ -82,16 +86,16 @@ func Test_InstancePrepares(t *testing.T) {
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
name string
prepare interface{}
additionalArgs []reflect.Value
want want
object interface{}
}{
{
name: "prepareInstanceQuery no result",
prepare: func() (sq.SelectBuilder, func(*sql.Row) (*Instance, error)) {
return prepareInstanceQuery("")
},
name: "prepareInstanceQuery no result",
additionalArgs: []reflect.Value{reflect.ValueOf("")},
prepare: prepareInstanceQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(instanceQuery),
@@ -108,10 +112,9 @@ func Test_InstancePrepares(t *testing.T) {
object: (*Instance)(nil),
},
{
name: "prepareInstanceQuery found",
prepare: func() (sq.SelectBuilder, func(*sql.Row) (*Instance, error)) {
return prepareInstanceQuery("")
},
name: "prepareInstanceQuery found",
additionalArgs: []reflect.Value{reflect.ValueOf("")},
prepare: prepareInstanceQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(instanceQuery),
@@ -142,10 +145,9 @@ func Test_InstancePrepares(t *testing.T) {
},
},
{
name: "prepareInstanceQuery sql err",
prepare: func() (sq.SelectBuilder, func(*sql.Row) (*Instance, error)) {
return prepareInstanceQuery("")
},
name: "prepareInstanceQuery sql err",
additionalArgs: []reflect.Value{reflect.ValueOf("")},
prepare: prepareInstanceQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(instanceQuery),
@@ -162,8 +164,8 @@ func Test_InstancePrepares(t *testing.T) {
},
{
name: "prepareInstancesQuery no result",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery()
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery(ctx, db)
return query(filter), scan
},
want: want{
@@ -177,8 +179,8 @@ func Test_InstancePrepares(t *testing.T) {
},
{
name: "prepareInstancesQuery one result",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery()
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery(ctx, db)
return query(filter), scan
},
want: want{
@@ -241,8 +243,8 @@ func Test_InstancePrepares(t *testing.T) {
},
{
name: "prepareInstancesQuery multiple results",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery()
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery(ctx, db)
return query(filter), scan
},
want: want{
@@ -374,8 +376,8 @@ func Test_InstancePrepares(t *testing.T) {
},
{
name: "prepareInstancesQuery sql err",
prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery()
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
filter, query, scan := prepareInstancesQuery(ctx, db)
return query(filter), scan
},
want: want{
@@ -395,7 +397,7 @@ func Test_InstancePrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, append(defaultPrepareArgs, tt.additionalArgs...)...)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
@@ -180,7 +181,7 @@ func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (_ *PublicK
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := preparePublicKeysQuery()
query, scan := preparePublicKeysQuery(ctx, q.client)
if t.IsZero() {
t = time.Now()
}
@@ -212,7 +213,7 @@ func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (_ *
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := preparePrivateKeysQuery()
stmt, scan := preparePrivateKeysQuery(ctx, q.client)
if t.IsZero() {
t = time.Now()
}
@@ -243,7 +244,7 @@ func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (_ *
return keys, nil
}
func preparePublicKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PublicKeys, error)) {
func preparePublicKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*PublicKeys, error)) {
return sq.Select(
KeyColID.identifier(),
KeyColCreationDate.identifier(),
@@ -256,7 +257,7 @@ func preparePublicKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PublicKeys, e
KeyPublicColKey.identifier(),
countColumn.identifier(),
).From(keyTable.identifier()).
LeftJoin(join(KeyPublicColID, KeyColID)).
LeftJoin(join(KeyPublicColID, KeyColID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*PublicKeys, error) {
keys := make([]PublicKey, 0)
@@ -299,7 +300,7 @@ func preparePublicKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PublicKeys, e
}
}
func preparePrivateKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PrivateKeys, error)) {
func preparePrivateKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*PrivateKeys, error)) {
return sq.Select(
KeyColID.identifier(),
KeyColCreationDate.identifier(),
@@ -312,7 +313,7 @@ func preparePrivateKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PrivateKeys,
KeyPrivateColKey.identifier(),
countColumn.identifier(),
).From(keyTable.identifier()).
LeftJoin(join(KeyPrivateColID, KeyColID)).
LeftJoin(join(KeyPrivateColID, KeyColID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*PrivateKeys, error) {
keys := make([]PrivateKey, 0)

View File

@@ -15,6 +15,48 @@ import (
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
preparePublicKeysStmt = `SELECT projections.keys4.id,` +
` projections.keys4.creation_date,` +
` projections.keys4.change_date,` +
` projections.keys4.sequence,` +
` projections.keys4.resource_owner,` +
` projections.keys4.algorithm,` +
` projections.keys4.use,` +
` projections.keys4_public.expiry,` +
` projections.keys4_public.key,` +
` COUNT(*) OVER ()` +
` FROM projections.keys4` +
` LEFT JOIN projections.keys4_public ON projections.keys4.id = projections.keys4_public.id AND projections.keys4.instance_id = projections.keys4_public.instance_id` +
` AS OF SYSTEM TIME '-1 ms' `
preparePublicKeysCols = []string{
"id",
"creation_date",
"change_date",
"sequence",
"resource_owner",
"algorithm",
"use",
"expiry",
"key",
"count",
}
preparePrivateKeysStmt = `SELECT projections.keys4.id,` +
` projections.keys4.creation_date,` +
` projections.keys4.change_date,` +
` projections.keys4.sequence,` +
` projections.keys4.resource_owner,` +
` projections.keys4.algorithm,` +
` projections.keys4.use,` +
` projections.keys4_private.expiry,` +
` projections.keys4_private.key,` +
` COUNT(*) OVER ()` +
` FROM projections.keys4` +
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id` +
` AS OF SYSTEM TIME '-1 ms' `
)
func Test_KeyPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -31,18 +73,7 @@ func Test_KeyPrepares(t *testing.T) {
prepare: preparePublicKeysQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_public.expiry,`+
` projections.keys4_public.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_public ON projections.keys4.id = projections.keys4_public.id`),
regexp.QuoteMeta(preparePublicKeysStmt),
nil,
nil,
),
@@ -60,30 +91,8 @@ func Test_KeyPrepares(t *testing.T) {
prepare: preparePublicKeysQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_public.expiry,`+
` projections.keys4_public.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_public ON projections.keys4.id = projections.keys4_public.id`),
[]string{
"id",
"creation_date",
"change_date",
"sequence",
"resource_owner",
"algorithm",
"use",
"expiry",
"key",
"count",
},
regexp.QuoteMeta(preparePublicKeysStmt),
preparePublicKeysCols,
[][]driver.Value{
{
"key-id",
@@ -128,18 +137,7 @@ func Test_KeyPrepares(t *testing.T) {
prepare: preparePublicKeysQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_public.expiry,`+
` projections.keys4_public.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_public ON projections.keys4.id = projections.keys4_public.id`),
regexp.QuoteMeta(preparePublicKeysStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -156,18 +154,7 @@ func Test_KeyPrepares(t *testing.T) {
prepare: preparePrivateKeysQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_private.expiry,`+
` projections.keys4_private.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id`),
regexp.QuoteMeta(preparePrivateKeysStmt),
nil,
nil,
),
@@ -185,30 +172,8 @@ func Test_KeyPrepares(t *testing.T) {
prepare: preparePrivateKeysQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_private.expiry,`+
` projections.keys4_private.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id`),
[]string{
"id",
"creation_date",
"change_date",
"sequence",
"resource_owner",
"algorithm",
"use",
"expiry",
"key",
"count",
},
regexp.QuoteMeta(preparePrivateKeysStmt),
preparePublicKeysCols,
[][]driver.Value{
{
"key-id",
@@ -255,18 +220,7 @@ func Test_KeyPrepares(t *testing.T) {
prepare: preparePrivateKeysQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.keys4.id,`+
` projections.keys4.creation_date,`+
` projections.keys4.change_date,`+
` projections.keys4.sequence,`+
` projections.keys4.resource_owner,`+
` projections.keys4.algorithm,`+
` projections.keys4.use,`+
` projections.keys4_private.expiry,`+
` projections.keys4_private.key,`+
` COUNT(*) OVER ()`+
` FROM projections.keys4`+
` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id`),
regexp.QuoteMeta(preparePrivateKeysStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -281,7 +235,7 @@ func Test_KeyPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -45,7 +46,7 @@ func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, with
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareLabelPolicyQuery()
stmt, scan := prepareLabelPolicyQuery(ctx, q.client)
eq := sq.Eq{
LabelPolicyColState.identifier(): domain.LabelPolicyStateActive,
LabelPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -75,7 +76,7 @@ func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (_
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareLabelPolicyQuery()
stmt, scan := prepareLabelPolicyQuery(ctx, q.client)
query, args, err := stmt.Where(
sq.And{
sq.Or{
@@ -105,7 +106,7 @@ func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (_ *LabelPolicy,
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareLabelPolicyQuery()
stmt, scan := prepareLabelPolicyQuery(ctx, q.client)
query, args, err := stmt.Where(sq.Eq{
LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
LabelPolicyColState.identifier(): domain.LabelPolicyStateActive,
@@ -125,7 +126,7 @@ func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (_ *LabelPolicy
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareLabelPolicyQuery()
stmt, scan := prepareLabelPolicyQuery(ctx, q.client)
query, args, err := stmt.Where(sq.Eq{
LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
LabelPolicyColState.identifier(): domain.LabelPolicyStatePreview,
@@ -223,7 +224,7 @@ var (
}
)
func prepareLabelPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*LabelPolicy, error)) {
func prepareLabelPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LabelPolicy, error)) {
return sq.Select(
LabelPolicyColCreationDate.identifier(),
LabelPolicyColChangeDate.identifier(),
@@ -252,7 +253,8 @@ func prepareLabelPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*LabelPolicy,
LabelPolicyColDarkLogoURL.identifier(),
LabelPolicyColDarkIconURL.identifier(),
).
From(labelPolicyTable.identifier()).PlaceholderFormat(sq.Dollar),
From(labelPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*LabelPolicy, error) {
policy := new(LabelPolicy)

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -94,7 +95,7 @@ func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool
eq[LockoutPolicyOwnerRemoved.identifier()] = false
}
stmt, scan := prepareLockoutPolicyQuery()
stmt, scan := prepareLockoutPolicyQuery(ctx, q.client)
query, args, err := stmt.Where(
sq.And{
eq,
@@ -117,7 +118,7 @@ func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (_ *LockoutPolicy, e
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareLockoutPolicyQuery()
stmt, scan := prepareLockoutPolicyQuery(ctx, q.client)
query, args, err := stmt.Where(sq.Eq{
LockoutColID.identifier(): authz.GetInstance(ctx).InstanceID(),
LockoutColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -132,7 +133,7 @@ func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (_ *LockoutPolicy, e
return scan(row)
}
func prepareLockoutPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*LockoutPolicy, error)) {
func prepareLockoutPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LockoutPolicy, error)) {
return sq.Select(
LockoutColID.identifier(),
LockoutColSequence.identifier(),
@@ -144,7 +145,8 @@ func prepareLockoutPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*LockoutPoli
LockoutColIsDefault.identifier(),
LockoutColState.identifier(),
).
From(lockoutTable.identifier()).PlaceholderFormat(sq.Dollar),
From(lockoutTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*LockoutPolicy, error) {
policy := new(LockoutPolicy)
err := row.Scan(

View File

@@ -12,6 +12,32 @@ import (
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
prepareLockoutPolicyStmt = `SELECT projections.lockout_policies2.id,` +
` projections.lockout_policies2.sequence,` +
` projections.lockout_policies2.creation_date,` +
` projections.lockout_policies2.change_date,` +
` projections.lockout_policies2.resource_owner,` +
` projections.lockout_policies2.show_failure,` +
` projections.lockout_policies2.max_password_attempts,` +
` projections.lockout_policies2.is_default,` +
` projections.lockout_policies2.state` +
` FROM projections.lockout_policies2` +
` AS OF SYSTEM TIME '-1 ms'`
prepareLockoutPolicyCols = []string{
"id",
"sequence",
"creation_date",
"change_date",
"resource_owner",
"show_failure",
"max_password_attempts",
"is_default",
"state",
}
)
func Test_LockoutPolicyPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
@@ -28,16 +54,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) {
prepare: prepareLockoutPolicyQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.lockout_policies2.id,`+
` projections.lockout_policies2.sequence,`+
` projections.lockout_policies2.creation_date,`+
` projections.lockout_policies2.change_date,`+
` projections.lockout_policies2.resource_owner,`+
` projections.lockout_policies2.show_failure,`+
` projections.lockout_policies2.max_password_attempts,`+
` projections.lockout_policies2.is_default,`+
` projections.lockout_policies2.state`+
` FROM projections.lockout_policies2`),
regexp.QuoteMeta(prepareLockoutPolicyStmt),
nil,
nil,
),
@@ -55,27 +72,8 @@ func Test_LockoutPolicyPrepares(t *testing.T) {
prepare: prepareLockoutPolicyQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.lockout_policies2.id,`+
` projections.lockout_policies2.sequence,`+
` projections.lockout_policies2.creation_date,`+
` projections.lockout_policies2.change_date,`+
` projections.lockout_policies2.resource_owner,`+
` projections.lockout_policies2.show_failure,`+
` projections.lockout_policies2.max_password_attempts,`+
` projections.lockout_policies2.is_default,`+
` projections.lockout_policies2.state`+
` FROM projections.lockout_policies2`),
[]string{
"id",
"sequence",
"creation_date",
"change_date",
"resource_owner",
"show_failure",
"max_password_attempts",
"is_default",
"state",
},
regexp.QuoteMeta(prepareLockoutPolicyStmt),
prepareLockoutPolicyCols,
[]driver.Value{
"pol-id",
uint64(20211109),
@@ -106,16 +104,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) {
prepare: prepareLockoutPolicyQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.lockout_policies2.id,`+
` projections.lockout_policies2.sequence,`+
` projections.lockout_policies2.creation_date,`+
` projections.lockout_policies2.change_date,`+
` projections.lockout_policies2.resource_owner,`+
` projections.lockout_policies2.show_failure,`+
` projections.lockout_policies2.max_password_attempts,`+
` projections.lockout_policies2.is_default,`+
` projections.lockout_policies2.state`+
` FROM projections.lockout_policies2`),
regexp.QuoteMeta(prepareLockoutPolicyStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -130,7 +119,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
@@ -172,7 +173,7 @@ func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, o
eq[LoginPolicyColumnOwnerRemoved.identifier()] = false
}
query, scan := prepareLoginPolicyQuery()
query, scan := prepareLoginPolicyQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.And{
eq,
@@ -212,7 +213,7 @@ func (q *Queries) DefaultLoginPolicy(ctx context.Context) (_ *LoginPolicy, err e
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLoginPolicyQuery()
query, scan := prepareLoginPolicyQuery(ctx, q.client)
stmt, args, err := query.Where(sq.Eq{
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -232,7 +233,7 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (_ *Seco
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLoginPolicy2FAsQuery()
query, scan := prepareLoginPolicy2FAsQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.And{
sq.Eq{
@@ -266,7 +267,7 @@ func (q *Queries) DefaultSecondFactors(ctx context.Context) (_ *SecondFactors, e
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLoginPolicy2FAsQuery()
query, scan := prepareLoginPolicy2FAsQuery(ctx, q.client)
stmt, args, err := query.Where(sq.Eq{
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -288,7 +289,7 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (_ *Multi
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLoginPolicyMFAsQuery()
query, scan := prepareLoginPolicyMFAsQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.And{
sq.Eq{
@@ -322,7 +323,7 @@ func (q *Queries) DefaultMultiFactors(ctx context.Context) (_ *MultiFactors, err
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLoginPolicyMFAsQuery()
query, scan := prepareLoginPolicyMFAsQuery(ctx, q.client)
stmt, args, err := query.Where(sq.Eq{
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -340,7 +341,7 @@ func (q *Queries) DefaultMultiFactors(ctx context.Context) (_ *MultiFactors, err
return factors, err
}
func prepareLoginPolicyQuery() (sq.SelectBuilder, func(*sql.Rows) (*LoginPolicy, error)) {
func prepareLoginPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*LoginPolicy, error)) {
return sq.Select(
LoginPolicyColumnOrgID.identifier(),
LoginPolicyColumnCreationDate.identifier(),
@@ -365,7 +366,7 @@ func prepareLoginPolicyQuery() (sq.SelectBuilder, func(*sql.Rows) (*LoginPolicy,
LoginPolicyColumnMFAInitSkipLifetime.identifier(),
LoginPolicyColumnSecondFactorCheckLifetime.identifier(),
LoginPolicyColumnMultiFactorCheckLifetime.identifier(),
).From(loginPolicyTable.identifier()).
).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*LoginPolicy, error) {
p := new(LoginPolicy)
@@ -408,10 +409,11 @@ func prepareLoginPolicyQuery() (sq.SelectBuilder, func(*sql.Rows) (*LoginPolicy,
}
}
func prepareLoginPolicy2FAsQuery() (sq.SelectBuilder, func(*sql.Row) (*SecondFactors, error)) {
func prepareLoginPolicy2FAsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SecondFactors, error)) {
return sq.Select(
LoginPolicyColumnSecondFactors.identifier(),
).From(loginPolicyTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*SecondFactors, error) {
p := new(SecondFactors)
err := row.Scan(
@@ -429,10 +431,11 @@ func prepareLoginPolicy2FAsQuery() (sq.SelectBuilder, func(*sql.Row) (*SecondFac
}
}
func prepareLoginPolicyMFAsQuery() (sq.SelectBuilder, func(*sql.Row) (*MultiFactors, error)) {
func prepareLoginPolicyMFAsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MultiFactors, error)) {
return sq.Select(
LoginPolicyColumnMultiFactors.identifier(),
).From(loginPolicyTable.identifier()).PlaceholderFormat(sq.Dollar),
).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*MultiFactors, error) {
p := new(MultiFactors)
err := row.Scan(

View File

@@ -38,7 +38,8 @@ var (
` projections.login_policies4.mfa_init_skip_lifetime,` +
` projections.login_policies4.second_factor_check_lifetime,` +
` projections.login_policies4.multi_factor_check_lifetime` +
` FROM projections.login_policies4`
` FROM projections.login_policies4` +
` AS OF SYSTEM TIME '-1 ms'`
loginPolicyCols = []string{
"aggregate_id",
"creation_date",
@@ -64,6 +65,20 @@ var (
"second_factor_check_lifetime",
"multi_factor_check_lifetime",
}
prepareLoginPolicy2FAsStmt = `SELECT projections.login_policies4.second_factors` +
` FROM projections.login_policies4` +
` AS OF SYSTEM TIME '-1 ms'`
prepareLoginPolicy2FAsCols = []string{
"second_factors",
}
prepareLoginPolicyMFAsStmt = `SELECT projections.login_policies4.multi_factors` +
` FROM projections.login_policies4` +
` AS OF SYSTEM TIME '-1 ms'`
prepareLoginPolicyMFAsCols = []string{
"multi_factors",
}
)
func Test_LoginPolicyPrepares(t *testing.T) {
@@ -177,11 +192,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicy2FAsQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.login_policies4.second_factors`+
` FROM projections.login_policies4`),
[]string{
"second_factors",
},
regexp.QuoteMeta(prepareLoginPolicy2FAsStmt),
prepareLoginPolicy2FAsCols,
nil,
),
err: func(err error) (error, bool) {
@@ -198,11 +210,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicy2FAsQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.login_policies4.second_factors`+
` FROM projections.login_policies4`),
[]string{
"second_factors",
},
regexp.QuoteMeta(prepareLoginPolicy2FAsStmt),
prepareLoginPolicy2FAsCols,
[]driver.Value{
database.EnumArray[domain.SecondFactorType]{domain.SecondFactorTypeOTP},
},
@@ -220,11 +229,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicy2FAsQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.login_policies4.second_factors`+
` FROM projections.login_policies4`),
[]string{
"second_factors",
},
regexp.QuoteMeta(prepareLoginPolicy2FAsStmt),
prepareLoginPolicy2FAsCols,
[]driver.Value{
database.EnumArray[domain.SecondFactorType]{},
},
@@ -237,8 +243,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicy2FAsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.login_policies4.second_factors`+
` FROM projections.login_policies4`),
regexp.QuoteMeta(prepareLoginPolicy2FAsStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -255,11 +260,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicyMFAsQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.login_policies4.multi_factors`+
` FROM projections.login_policies4`),
[]string{
"multi_factors",
},
regexp.QuoteMeta(prepareLoginPolicyMFAsStmt),
prepareLoginPolicyMFAsCols,
nil,
),
err: func(err error) (error, bool) {
@@ -276,11 +278,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicyMFAsQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.login_policies4.multi_factors`+
` FROM projections.login_policies4`),
[]string{
"multi_factors",
},
regexp.QuoteMeta(prepareLoginPolicyMFAsStmt),
prepareLoginPolicyMFAsCols,
[]driver.Value{
database.EnumArray[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
},
@@ -298,11 +297,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicyMFAsQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(`SELECT projections.login_policies4.multi_factors`+
` FROM projections.login_policies4`),
[]string{
"multi_factors",
},
regexp.QuoteMeta(prepareLoginPolicyMFAsStmt),
prepareLoginPolicyMFAsCols,
[]driver.Value{
database.EnumArray[domain.MultiFactorType]{},
},
@@ -315,8 +311,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
prepare: prepareLoginPolicyMFAsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.login_policies4.multi_factors`+
` FROM projections.login_policies4`),
regexp.QuoteMeta(prepareLoginPolicyMFAsStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
@@ -331,7 +326,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -73,7 +74,7 @@ func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwner
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareMailTemplateQuery()
stmt, scan := prepareMailTemplateQuery(ctx, q.client)
eq := sq.Eq{MailTemplateColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[MailTemplateColOwnerRemoved.identifier()] = false
@@ -100,7 +101,7 @@ func (q *Queries) DefaultMailTemplate(ctx context.Context) (_ *MailTemplate, err
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareMailTemplateQuery()
stmt, scan := prepareMailTemplateQuery(ctx, q.client)
query, args, err := stmt.Where(sq.Eq{
MailTemplateColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
MailTemplateColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -115,7 +116,7 @@ func (q *Queries) DefaultMailTemplate(ctx context.Context) (_ *MailTemplate, err
return scan(row)
}
func prepareMailTemplateQuery() (sq.SelectBuilder, func(*sql.Row) (*MailTemplate, error)) {
func prepareMailTemplateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MailTemplate, error)) {
return sq.Select(
MailTemplateColAggregateID.identifier(),
MailTemplateColSequence.identifier(),
@@ -125,7 +126,8 @@ func prepareMailTemplateQuery() (sq.SelectBuilder, func(*sql.Row) (*MailTemplate
MailTemplateColIsDefault.identifier(),
MailTemplateColState.identifier(),
).
From(mailTemplateTable.identifier()).PlaceholderFormat(sq.Dollar),
From(mailTemplateTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*MailTemplate, error) {
policy := new(MailTemplate)
err := row.Scan(

View File

@@ -16,6 +16,7 @@ import (
"sigs.k8s.io/yaml"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -127,7 +128,7 @@ func (q *Queries) DefaultMessageText(ctx context.Context) (_ *MessageText, err e
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareMessageTextQuery()
stmt, scan := prepareMessageTextQuery(ctx, q.client)
query, args, err := stmt.Where(sq.Eq{
MessageTextColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
MessageTextColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -160,7 +161,7 @@ func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggreg
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareMessageTextQuery()
stmt, scan := prepareMessageTextQuery(ctx, q.client)
eq := sq.Eq{
MessageTextColLanguage.identifier(): language,
MessageTextColType.identifier(): messageType,
@@ -240,7 +241,7 @@ func (q *Queries) readNotificationTextMessages(ctx context.Context, language str
return contents, nil
}
func prepareMessageTextQuery() (sq.SelectBuilder, func(*sql.Row) (*MessageText, error)) {
func prepareMessageTextQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MessageText, error)) {
return sq.Select(
MessageTextColAggregateID.identifier(),
MessageTextColSequence.identifier(),
@@ -257,7 +258,8 @@ func prepareMessageTextQuery() (sq.SelectBuilder, func(*sql.Row) (*MessageText,
MessageTextColButtonText.identifier(),
MessageTextColFooter.identifier(),
).
From(messageTextTable.identifier()).PlaceholderFormat(sq.Dollar),
From(messageTextTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*MessageText, error) {
msg := new(MessageText)
lang := ""

Some files were not shown because too many files have changed in this diff Show More