feat: projections auto create their tables (#3324)

* begin init checks for projections

* first projection checks

* debug notification providers with query fixes

* more projections and first index

* more projections

* more projections

* finish projections

* fix tests (remove db name)

* create tables in setup

* fix logging / error handling

* add tenant to views

* rename tenant to instance_id

* add instance_id to all projections

* add instance_id to all queries

* correct instance_id on projections

* add instance_id to failed_events

* use separate context for instance

* implement features projection

* implement features projection

* remove unique constraint from setup when migration failed

* add error to failed setup event

* add instance_id to primary keys

* fix IAM projection

* remove old migrations folder

* fix keysFromYAML test
This commit is contained in:
Livio Amstutz 2022-03-23 09:02:39 +01:00 committed by GitHub
parent 9e13b70a3d
commit 56b916a2b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
400 changed files with 6508 additions and 8890 deletions

View File

@ -12,13 +12,13 @@ CREATE TABLE eventstore.events (
, editor_user TEXT NOT NULL
, editor_service TEXT NOT NULL
, resource_owner TEXT NOT NULL
, tenant TEXT
, instance_id TEXT
, PRIMARY KEY (event_sequence DESC) USING HASH WITH BUCKET_COUNT = 10
, INDEX agg_type_agg_id (aggregate_type, aggregate_id)
, INDEX agg_type (aggregate_type)
, INDEX agg_type_seq (aggregate_type, event_sequence DESC)
STORING (id, event_type, aggregate_id, aggregate_version, previous_aggregate_sequence, creation_date, event_data, editor_user, editor_service, resource_owner, tenant, previous_aggregate_type_sequence)
STORING (id, event_type, aggregate_id, aggregate_version, previous_aggregate_sequence, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_type_sequence)
, INDEX max_sequence (aggregate_type, aggregate_id, event_sequence DESC)
, CONSTRAINT previous_sequence_unique UNIQUE (previous_aggregate_sequence DESC)
, CONSTRAINT prev_agg_type_seq_unique UNIQUE(previous_aggregate_type_sequence)

View File

@ -155,7 +155,7 @@ func Test_keysFromYAML(t *testing.T) {
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
assert.EqualValues(t, got, tt.res.keys)
assert.ElementsMatch(t, got, tt.res.keys)
})
}
}

View File

@ -0,0 +1,34 @@
package setup
import (
"context"
"database/sql"
_ "embed"
)
var (
//go:embed 01_sql/adminapi.sql
createAdminViews string
//go:embed 01_sql/auth.sql
createAuthViews string
//go:embed 01_sql/authz.sql
createAuthzViews string
//go:embed 01_sql/notification.sql
createNotificationViews string
//go:embed 01_sql/projections.sql
createProjections string
)
type ProjectionTable struct {
dbClient *sql.DB
}
func (mig *ProjectionTable) Execute(ctx context.Context) error {
stmt := createAdminViews + createAuthViews + createAuthzViews + createNotificationViews + createProjections
_, err := mig.dbClient.ExecContext(ctx, stmt)
return err
}
func (mig *ProjectionTable) String() string {
return "01_tables"
}

View File

@ -0,0 +1,54 @@
CREATE SCHEMA adminapi;
CREATE TABLE adminapi.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE adminapi.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE adminapi.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE adminapi.styling (
aggregate_id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
label_policy_state INT2 NOT NULL DEFAULT 0:::INT2,
sequence INT8 NULL,
primary_color STRING NULL,
background_color STRING NULL,
warn_color STRING NULL,
font_color STRING NULL,
primary_color_dark STRING NULL,
background_color_dark STRING NULL,
warn_color_dark STRING NULL,
font_color_dark STRING NULL,
logo_url STRING NULL,
icon_url STRING NULL,
logo_dark_url STRING NULL,
icon_dark_url STRING NULL,
font_url STRING NULL,
err_msg_popup BOOL NULL,
disable_watermark BOOL NULL,
hide_login_name_suffix BOOL NULL,
instance_id STRING NOT NULL,
PRIMARY KEY (aggregate_id, label_policy_state)
);

View File

@ -0,0 +1,222 @@
CREATE SCHEMA auth;
CREATE TABLE auth.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE auth.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE auth.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE auth.users (
id STRING NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
user_state INT2 NULL,
password_set BOOL NULL,
password_change_required BOOL NULL,
password_change TIMESTAMPTZ NULL,
last_login TIMESTAMPTZ NULL,
user_name STRING NULL,
login_names STRING[] NULL,
preferred_login_name STRING NULL,
first_name STRING NULL,
last_name STRING NULL,
nick_name STRING NULL,
display_name STRING NULL,
preferred_language STRING NULL,
gender INT2 NULL,
email STRING NULL,
is_email_verified BOOL NULL,
phone STRING NULL,
is_phone_verified BOOL NULL,
country STRING NULL,
locality STRING NULL,
postal_code STRING NULL,
region STRING NULL,
street_address STRING NULL,
otp_state INT2 NULL,
mfa_max_set_up INT2 NULL,
mfa_init_skipped TIMESTAMPTZ NULL,
sequence INT8 NULL,
init_required BOOL NULL,
username_change_required BOOL NULL,
machine_name STRING NULL,
machine_description STRING NULL,
user_type STRING NULL,
u2f_tokens BYTES NULL,
passwordless_tokens BYTES NULL,
avatar_key STRING NULL,
passwordless_init_required BOOL NULL,
password_init_required BOOL NULL,
instance_id STRING NULL,
PRIMARY KEY (id)
);
CREATE TABLE auth.user_sessions (
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
state INT2 NULL,
user_agent_id STRING NULL,
user_id STRING NULL,
user_name STRING NULL,
password_verification TIMESTAMPTZ NULL,
second_factor_verification TIMESTAMPTZ NULL,
multi_factor_verification TIMESTAMPTZ NULL,
sequence INT8 NULL,
second_factor_verification_type INT2 NULL,
multi_factor_verification_type INT2 NULL,
user_display_name STRING NULL,
login_name STRING NULL,
external_login_verification TIMESTAMPTZ NULL,
selected_idp_config_id STRING NULL,
passwordless_verification TIMESTAMPTZ NULL,
avatar_key STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (user_agent_id, user_id)
);
CREATE TABLE auth.user_external_idps (
external_user_id STRING NOT NULL,
idp_config_id STRING NOT NULL,
user_id STRING NULL,
idp_name STRING NULL,
user_display_name STRING NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
resource_owner STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (external_user_id, idp_config_id)
);
CREATE TABLE auth.tokens (
id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
application_id STRING NULL,
user_agent_id STRING NULL,
user_id STRING NULL,
expiration TIMESTAMPTZ NULL,
sequence INT8 NULL,
scopes STRING[] NULL,
audience STRING[] NULL,
preferred_language STRING NULL,
refresh_token_id STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (id),
INDEX user_user_agent_idx (user_id, user_agent_id)
);
CREATE TABLE auth.refresh_tokens (
id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
token STRING NULL,
client_id STRING NOT NULL,
user_agent_id STRING NOT NULL,
user_id STRING NOT NULL,
auth_time TIMESTAMPTZ NULL,
idle_expiration TIMESTAMPTZ NULL,
expiration TIMESTAMPTZ NULL,
sequence INT8 NULL,
scopes STRING[] NULL,
audience STRING[] NULL,
amr STRING[] NULL,
instance_id STRING NULL,
PRIMARY KEY (id),
UNIQUE INDEX unique_client_user_index (client_id ASC, user_agent_id ASC, user_id ASC)
);
CREATE TABLE auth.org_project_mapping (
org_id STRING NOT NULL,
project_id STRING NOT NULL,
project_grant_id STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (org_id, project_id)
);
CREATE TABLE auth.idp_providers (
aggregate_id STRING NOT NULL,
idp_config_id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
name STRING NULL,
idp_config_type INT2 NULL,
idp_provider_type INT2 NULL,
idp_state INT2 NULL,
styling_type INT2 NULL,
instance_id STRING NULL,
PRIMARY KEY (aggregate_id, idp_config_id)
);
CREATE TABLE auth.idp_configs (
idp_config_id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
aggregate_id STRING NULL,
name STRING NULL,
idp_state INT2 NULL,
idp_provider_type INT2 NULL,
is_oidc BOOL NULL,
oidc_client_id STRING NULL,
oidc_client_secret JSONB NULL,
oidc_issuer STRING NULL,
oidc_scopes STRING[] NULL,
oidc_idp_display_name_mapping INT2 NULL,
oidc_idp_username_mapping INT2 NULL,
styling_type INT2 NULL,
oauth_authorization_endpoint STRING NULL,
oauth_token_endpoint STRING NULL,
auto_register BOOL NULL,
jwt_endpoint STRING NULL,
jwt_keys_endpoint STRING NULL,
jwt_header_name STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (idp_config_id)
);
CREATE TABLE auth.auth_requests (
id STRING NOT NULL,
request JSONB NULL,
code STRING NULL,
request_type INT2 NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
instance_id STRING NULL,
PRIMARY KEY (id),
INDEX auth_code_idx (code)
);

View File

@ -0,0 +1,44 @@
CREATE SCHEMA authz;
CREATE TABLE authz.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE authz.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE authz.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE authz.user_memberships (
user_id STRING NOT NULL,
member_type INT2 NOT NULL,
aggregate_id STRING NOT NULL,
object_id STRING NOT NULL,
roles STRING[] NULL,
display_name STRING NULL,
resource_owner STRING NULL,
resource_owner_name STRING NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
instance_id STRING NULL,
PRIMARY KEY (user_id, member_type, aggregate_id, object_id)
);

View File

@ -0,0 +1,52 @@
CREATE SCHEMA notification;
CREATE TABLE notification.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE notification.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE notification.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE notification.notify_users (
id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
user_name STRING NULL,
first_name STRING NULL,
last_name STRING NULL,
nick_name STRING NULL,
display_name STRING NULL,
preferred_language STRING NULL,
gender INT2 NULL,
last_email STRING NULL,
verified_email STRING NULL,
last_phone STRING NULL,
verified_phone STRING NULL,
sequence INT8 NULL,
password_set BOOL NULL,
login_names STRING NULL,
preferred_login_name STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (id)
);

View File

@ -1,10 +1,4 @@
CREATE DATABASE zitadel;
GRANT SELECT, INSERT, UPDATE, DELETE ON DATABASE zitadel TO queries;
use zitadel;
CREATE SCHEMA zitadel.projections AUTHORIZATION queries;
CREATE TABLE zitadel.projections.locks (
CREATE TABLE projections.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
@ -12,7 +6,7 @@ CREATE TABLE zitadel.projections.locks (
PRIMARY KEY (projection_name)
);
CREATE TABLE zitadel.projections.current_sequences (
CREATE TABLE projections.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
@ -21,11 +15,12 @@ CREATE TABLE zitadel.projections.current_sequences (
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE zitadel.projections.failed_events (
CREATE TABLE projections.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
instance_id TEXT,
PRIMARY KEY (projection_name, failed_sequence)
PRIMARY KEY (projection_name, failed_sequence, instance_id)
);

13
cmd/admin/setup/config.go Normal file
View File

@ -0,0 +1,13 @@
package setup
import (
"github.com/caos/zitadel/internal/database"
)
type Config struct {
Database database.Config
}
type Steps struct {
S1ProjectionTable *ProjectionTable
}

View File

@ -1,10 +1,22 @@
package setup
import (
"bytes"
"context"
_ "embed"
"github.com/caos/logging"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/caos/zitadel/internal/database"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/migration"
)
var (
//go:embed steps.yaml
defaultSteps []byte
)
func New() *cobra.Command {
@ -14,9 +26,33 @@ func New() *cobra.Command {
Long: `sets up data to start ZITADEL.
Requirements:
- cockroachdb`,
RunE: func(cmd *cobra.Command, args []string) error {
logging.Info("hello world")
return nil
Run: func(cmd *cobra.Command, args []string) {
config := new(Config)
err := viper.Unmarshal(config)
logging.OnError(err).Fatal("unable to read config")
v := viper.New()
v.SetConfigType("yaml")
err = v.ReadConfig(bytes.NewBuffer(defaultSteps))
logging.OnError(err).Fatal("unable to read setup steps")
steps := new(Steps)
err = v.Unmarshal(steps)
logging.OnError(err).Fatal("unable to read steps")
setup(config, steps)
},
}
}
func setup(config *Config, steps *Steps) {
dbClient, err := database.Connect(config.Database)
logging.OnError(err).Fatal("unable to connect to database")
eventstoreClient, err := eventstore.Start(dbClient)
logging.OnError(err).Fatal("unable to start eventstore")
steps.S1ProjectionTable = &ProjectionTable{dbClient: dbClient}
migration.Migrate(context.Background(), eventstoreClient, steps.S1ProjectionTable)
}

View File

@ -41,7 +41,6 @@ import (
"github.com/caos/zitadel/internal/crypto"
cryptoDB "github.com/caos/zitadel/internal/crypto/database"
"github.com/caos/zitadel/internal/database"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/id"
"github.com/caos/zitadel/internal/notification"
@ -308,7 +307,7 @@ func shutdownServer(ctx context.Context, server *http.Server) error {
//TODO:!!??!!
func consoleClientID(ctx context.Context, queries *query.Queries) (string, error) {
iam, err := queries.IAMByID(ctx, domain.IAMID)
iam, err := queries.IAM(ctx)
if err != nil {
return "", err
}

View File

@ -32,7 +32,7 @@ type API struct {
type health interface {
Health(ctx context.Context) error
IAMByID(ctx context.Context, id string) (*query.IAM, error)
IAM(ctx context.Context) (*query.IAM, error)
}
func New(
@ -107,7 +107,7 @@ func (a *API) healthHandler() http.Handler {
return nil
},
func(ctx context.Context) error {
iam, err := a.health.IAMByID(ctx, domain.IAMID)
iam, err := a.health.IAM(ctx)
if err != nil && !errors.IsNotFound(err) {
return errors.ThrowPreconditionFailed(err, "API-dsgT2", "IAM SETUP CHECK FAILED")
}

View File

@ -15,12 +15,12 @@ const (
requestPermissionsKey key = 1
dataKey key = 2
allPermissionsKey key = 3
instanceKey key = 4
)
type CtxData struct {
UserID string
OrgID string
TenantID string //TODO: Set Tenant ID on some context
ProjectID string
AgentID string
PreferredLanguage string
@ -31,6 +31,10 @@ func (ctxData CtxData) IsZero() bool {
return ctxData.UserID == "" || ctxData.OrgID == ""
}
type Instance struct {
ID string
}
type Grants []*Grant
type Grant struct {
@ -43,7 +47,7 @@ type Memberships []*Membership
type Membership struct {
MemberType MemberType
AggregateID string
//ObjectID differs from aggregate id if obejct is sub of an aggregate
//ObjectID differs from aggregate id if object is sub of an aggregate
ObjectID string
Roles []string
@ -112,6 +116,11 @@ func GetCtxData(ctx context.Context) CtxData {
return ctxData
}
func GetInstance(ctx context.Context) Instance {
instance, _ := ctx.Value(instanceKey).(Instance)
return instance
}
func GetRequestPermissionsFromCtx(ctx context.Context) []string {
ctxPermission, _ := ctx.Value(requestPermissionsKey).([]string)
return ctxPermission

View File

@ -2,11 +2,13 @@ package authz
import "context"
func NewMockContext(tenantID, orgID, userID string) context.Context {
return context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID, TenantID: tenantID})
func NewMockContext(instanceID, orgID, userID string) context.Context {
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
return context.WithValue(ctx, instanceKey, instanceID)
}
func NewMockContextWithPermissions(tenantID, orgID, userID string, permissions []string) context.Context {
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID, TenantID: tenantID})
func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context {
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
ctx = context.WithValue(ctx, instanceKey, instanceID)
return context.WithValue(ctx, requestPermissionsKey, permissions)
}

View File

@ -3,12 +3,12 @@ package admin
import (
"context"
"golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/grpc/object"
"github.com/caos/zitadel/internal/api/grpc/text"
"github.com/caos/zitadel/internal/domain"
caos_errors "github.com/caos/zitadel/internal/errors"
admin_pb "github.com/caos/zitadel/pkg/grpc/admin"
"golang.org/x/text/language"
)
func (s *Server) GetSupportedLanguages(ctx context.Context, req *admin_pb.GetSupportedLanguagesRequest) (*admin_pb.GetSupportedLanguagesResponse, error) {
@ -34,9 +34,5 @@ func (s *Server) SetDefaultLanguage(ctx context.Context, req *admin_pb.SetDefaul
}
func (s *Server) GetDefaultLanguage(ctx context.Context, req *admin_pb.GetDefaultLanguageRequest) (*admin_pb.GetDefaultLanguageResponse, error) {
iam, err := s.query.IAMByID(ctx, domain.IAMID)
if err != nil {
return nil, err
}
return &admin_pb.GetDefaultLanguageResponse{Language: iam.DefaultLanguage.String()}, nil
return &admin_pb.GetDefaultLanguageResponse{Language: s.query.GetDefaultLanguage(ctx).String()}, nil
}

View File

@ -152,7 +152,7 @@ func (s *Server) ListMyProjectOrgs(ctx context.Context, req *auth_pb.ListMyProje
return nil, err
}
iam, err := s.query.IAMByID(ctx, domain.IAMID)
iam, err := s.query.IAM(ctx)
if err != nil {
return nil, err
}

View File

@ -3,12 +3,11 @@ package management
import (
"context"
"github.com/caos/zitadel/internal/domain"
mgmt_pb "github.com/caos/zitadel/pkg/grpc/management"
)
func (s *Server) GetIAM(ctx context.Context, req *mgmt_pb.GetIAMRequest) (*mgmt_pb.GetIAMResponse, error) {
iam, err := s.query.IAMByID(ctx, domain.IAMID)
func (s *Server) GetIAM(ctx context.Context, _ *mgmt_pb.GetIAMRequest) (*mgmt_pb.GetIAMResponse, error) {
iam, err := s.query.IAM(ctx)
if err != nil {
return nil, err
}

View File

@ -206,8 +206,8 @@ func (s *Server) SetPrimaryOrgDomain(ctx context.Context, req *mgmt_pb.SetPrimar
}, nil
}
func (s *Server) ListOrgMemberRoles(ctx context.Context, req *mgmt_pb.ListOrgMemberRolesRequest) (*mgmt_pb.ListOrgMemberRolesResponse, error) {
iam, err := s.query.IAMByID(ctx, domain.IAMID)
func (s *Server) ListOrgMemberRoles(ctx context.Context, _ *mgmt_pb.ListOrgMemberRolesRequest) (*mgmt_pb.ListOrgMemberRolesResponse, error) {
iam, err := s.query.IAM(ctx)
if err != nil {
return nil, err
}

View File

@ -21,6 +21,10 @@ func UserAgentIDFromCtx(ctx context.Context) (string, bool) {
return userAgentID, ok
}
func InstanceIDFromCtx(ctx context.Context) string {
return "" //TODO: implement
}
type UserAgent struct {
ID string
}

View File

@ -10,6 +10,7 @@ import (
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query"
@ -45,7 +46,8 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
}
resp, err := o.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID)
instanceID := authz.GetInstance(ctx).ID
resp, err := o.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID, instanceID)
if err != nil {
return nil, err
}
@ -55,7 +57,9 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
resp, err := o.repo.AuthRequestByCode(ctx, code)
instanceID := authz.GetInstance(ctx).ID
resp, err := o.repo.AuthRequestByCode(ctx, code, instanceID)
if err != nil {
return nil, err
}
@ -69,13 +73,16 @@ func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err erro
if !ok {
return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
}
return o.repo.SaveAuthCode(ctx, id, code, userAgentID)
instanceID := authz.GetInstance(ctx).ID
return o.repo.SaveAuthCode(ctx, id, code, userAgentID, instanceID)
}
func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return o.repo.DeleteAuthRequest(ctx, id)
instanceID := authz.GetInstance(ctx).ID
return o.repo.DeleteAuthRequest(ctx, id, instanceID)
}
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {

View File

@ -10,6 +10,7 @@ import (
"github.com/caos/oidc/pkg/op"
"golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/authz"
http_utils "github.com/caos/zitadel/internal/api/http"
model2 "github.com/caos/zitadel/internal/auth_request/model"
"github.com/caos/zitadel/internal/domain"
@ -132,6 +133,7 @@ func CreateAuthRequestToBusiness(ctx context.Context, authReq *oidc.AuthRequest,
SelectedIDPConfigID: GetSelectedIDPIDFromScopes(authReq.Scopes),
MaxAuthAge: MaxAgeToBusiness(authReq.MaxAge),
UserID: userID,
InstanceID: authz.GetInstance(ctx).ID,
Request: &domain.AuthRequestOIDC{
Scopes: authReq.Scopes,
ResponseType: ResponseTypeToBusiness(authReq.ResponseType),

View File

@ -5,6 +5,7 @@ import (
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
)
@ -19,7 +20,8 @@ func (l *Login) getAuthRequest(r *http.Request) (*domain.AuthRequest, error) {
return nil, nil
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
return l.authRepo.AuthRequestByID(r.Context(), authRequestID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
return l.authRepo.AuthRequestByID(r.Context(), authRequestID, userAgentID, instanceID)
}
func (l *Login) getAuthRequestAndParseData(r *http.Request, data interface{}) (*domain.AuthRequest, error) {

View File

@ -16,7 +16,7 @@ func (l *Login) customExternalUserMapping(ctx context.Context, user *domain.Exte
resourceOwner = config.AggregateID
}
if resourceOwner == domain.IAMID {
iam, err := l.query.IAMByID(ctx, domain.IAMID)
iam, err := l.query.IAM(ctx)
if err != nil {
return nil, err
}

View File

@ -11,6 +11,8 @@ import (
"github.com/caos/oidc/pkg/oidc"
"golang.org/x/oauth2"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/crypto"
"github.com/caos/zitadel/internal/domain"
@ -87,7 +89,8 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID, instanceID)
if err != nil {
l.renderLogin(w, r, authReq, err)
return
@ -139,7 +142,8 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID, instanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return
@ -198,12 +202,13 @@ func (l *Login) handleExternalUserAuthenticated(w http.ResponseWriter, r *http.R
return
}
err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, userAgentID, externalUser, domain.BrowserInfoFromRequest(r))
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, userAgentID, instanceID, externalUser, domain.BrowserInfoFromRequest(r))
if err != nil {
if errors.IsNotFound(err) {
err = nil
}
iam, err := l.query.IAMByID(r.Context(), domain.IAMID)
iam, err := l.query.IAM(r.Context())
if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
return
@ -226,7 +231,7 @@ func (l *Login) handleExternalUserAuthenticated(w http.ResponseWriter, r *http.R
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIAMPolicy, human, idpLinking, err)
return
}
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID)
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID, instanceID)
if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIAMPolicy, human, idpLinking, err)
return
@ -235,7 +240,7 @@ func (l *Login) handleExternalUserAuthenticated(w http.ResponseWriter, r *http.R
return
}
if len(externalUser.Metadatas) > 0 {
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID)
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID, instanceID)
if err != nil {
return
}
@ -254,7 +259,7 @@ func (l *Login) renderExternalNotFoundOption(w http.ResponseWriter, r *http.Requ
errID, errMessage = l.getErrorMessage(r, err)
}
if orgIAMPolicy == nil {
iam, err = l.query.IAMByID(r.Context(), domain.IAMID)
iam, err = l.query.IAM(r.Context())
if err != nil {
l.renderError(w, r, authReq, err)
return
@ -324,7 +329,8 @@ func (l *Login) handleExternalNotFoundOptionCheck(w http.ResponseWriter, r *http
return
} else if data.ResetLinking {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.ResetLinkingUsers(r.Context(), authReq.ID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.ResetLinkingUsers(r.Context(), authReq.ID, userAgentID, instanceID)
if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
}
@ -335,7 +341,7 @@ func (l *Login) handleExternalNotFoundOptionCheck(w http.ResponseWriter, r *http
}
func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) {
iam, err := l.query.IAMByID(r.Context(), domain.IAMID)
iam, err := l.query.IAM(r.Context())
if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
return
@ -362,6 +368,7 @@ func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authR
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
instanceID := authz.GetInstance(r.Context()).ID
if len(authReq.LinkingUsers) == 0 {
l.renderError(w, r, authReq, caos_errors.ThrowPreconditionFailed(nil, "LOGIN-asfg3", "Errors.ExternalIDP.NoExternalUserData"))
return
@ -373,12 +380,12 @@ func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authR
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIamPolicy, nil, nil, err)
return
}
err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, memberRoles, authReq.ID, userAgentID, resourceOwner, metadata, domain.BrowserInfoFromRequest(r))
err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, memberRoles, authReq.ID, userAgentID, resourceOwner, instanceID, metadata, domain.BrowserInfoFromRequest(r))
if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIamPolicy, user, externalIDP, err)
return
}
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID)
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, instanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return

View File

@ -8,6 +8,7 @@ import (
"github.com/caos/oidc/pkg/oidc"
"golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain"
iam_model "github.com/caos/zitadel/internal/iam/model"
@ -67,7 +68,8 @@ func (l *Login) handleExternalRegister(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID, instanceID)
if err != nil {
l.renderLogin(w, r, authReq, err)
return
@ -87,7 +89,8 @@ func (l *Login) handleExternalRegisterCallback(w http.ResponseWriter, r *http.Re
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID, instanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return
@ -111,7 +114,7 @@ func (l *Login) handleExternalRegisterCallback(w http.ResponseWriter, r *http.Re
}
func (l *Login) handleExternalUserRegister(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, idpConfig *iam_model.IDPConfigView, userAgentID string, tokens *oidc.Tokens) {
iam, err := l.query.IAMByID(r.Context(), domain.IAMID)
iam, err := l.query.IAM(r.Context())
if err != nil {
l.renderRegisterOption(w, r, authReq, err)
return
@ -204,7 +207,7 @@ func (l *Login) handleExternalRegisterCheck(w http.ResponseWriter, r *http.Reque
return
}
iam, err := l.query.IAMByID(r.Context(), domain.IAMID)
iam, err := l.query.IAM(r.Context())
if err != nil {
l.renderRegisterOption(w, r, authReq, err)
return

View File

@ -12,6 +12,7 @@ import (
"github.com/caos/oidc/pkg/client/rp"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/zitadel/internal/api/authz"
http_util "github.com/caos/zitadel/internal/api/http"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
@ -44,7 +45,8 @@ func (l *Login) handleJWTRequest(w http.ResponseWriter, r *http.Request) {
l.renderError(w, r, nil, err)
return
}
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID, instanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return
@ -82,13 +84,13 @@ func (l *Login) handleJWTExtraction(w http.ResponseWriter, r *http.Request, auth
return
}
metadata := externalUser.Metadatas
err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, authReq.AgentID, externalUser, domain.BrowserInfoFromRequest(r))
err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID, externalUser, domain.BrowserInfoFromRequest(r))
if err != nil {
l.jwtExtractionUserNotFound(w, r, authReq, idpConfig, tokens, err)
return
}
if len(metadata) > 0 {
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID)
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return
@ -115,7 +117,7 @@ func (l *Login) jwtExtractionUserNotFound(w http.ResponseWriter, r *http.Request
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
return
}
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID)
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return
@ -133,12 +135,12 @@ func (l *Login) jwtExtractionUserNotFound(w http.ResponseWriter, r *http.Request
l.renderError(w, r, authReq, err)
return
}
err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, nil, authReq.ID, authReq.AgentID, resourceOwner, metadata, domain.BrowserInfoFromRequest(r))
err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, nil, authReq.ID, authReq.AgentID, resourceOwner, authReq.InstanceID, metadata, domain.BrowserInfoFromRequest(r))
if err != nil {
l.renderError(w, r, authReq, err)
return
}
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID)
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return
@ -207,7 +209,8 @@ func (l *Login) handleJWTCallback(w http.ResponseWriter, r *http.Request) {
l.renderError(w, r, nil, err)
return
}
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID, instanceID)
if err != nil {
l.renderError(w, r, authReq, err)
return

View File

@ -3,6 +3,7 @@ package login
import (
"net/http"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain"
)
@ -13,7 +14,8 @@ const (
func (l *Login) linkUsers(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.LinkExternalUsers(setContext(r.Context(), authReq.UserOrgID), authReq.ID, userAgentID, domain.BrowserInfoFromRequest(r))
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.LinkExternalUsers(setContext(r.Context(), authReq.UserOrgID), authReq.ID, userAgentID, instanceID, domain.BrowserInfoFromRequest(r))
l.renderLinkUsersDone(w, r, authReq, err)
}

View File

@ -3,6 +3,7 @@ package login
import (
"net/http"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
@ -59,8 +60,9 @@ func (l *Login) handleLoginNameCheck(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
instanceID := authz.GetInstance(r.Context()).ID
loginName := data.LoginName
err = l.authRepo.CheckLoginName(r.Context(), authReq.ID, loginName, userAgentID)
err = l.authRepo.CheckLoginName(r.Context(), authReq.ID, loginName, userAgentID, instanceID)
if err != nil {
l.renderLogin(w, r, authReq, err)
return

View File

@ -3,6 +3,7 @@ package login
import (
"net/http"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain"
)
@ -35,7 +36,8 @@ func (l *Login) handleMFAVerify(w http.ResponseWriter, r *http.Request) {
}
if data.MFAType == domain.MFATypeOTP {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.VerifyMFAOTP(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Code, userAgentID, domain.BrowserInfoFromRequest(r))
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.VerifyMFAOTP(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Code, userAgentID, instanceID, domain.BrowserInfoFromRequest(r))
if err != nil {
l.renderMFAVerifySelected(w, r, authReq, step, domain.MFATypeOTP, err)
return

View File

@ -6,6 +6,7 @@ import (
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
)
@ -29,7 +30,8 @@ func (l *Login) renderU2FVerification(w http.ResponseWriter, r *http.Request, au
var webAuthNLogin *domain.WebAuthNLogin
if err == nil {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
webAuthNLogin, err = l.authRepo.BeginMFAU2FLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
webAuthNLogin, err = l.authRepo.BeginMFAU2FLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, instanceID)
}
if err != nil {
errID, errMessage = l.getErrorMessage(r, err)
@ -70,7 +72,8 @@ func (l *Login) handleU2FVerification(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.VerifyMFAU2F(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, credData, domain.BrowserInfoFromRequest(r))
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.VerifyMFAU2F(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, instanceID, credData, domain.BrowserInfoFromRequest(r))
if err != nil {
l.renderU2FVerification(w, r, authReq, step.MFAProviders, err)
return

View File

@ -4,8 +4,6 @@ import (
"net/http"
"github.com/caos/zitadel/internal/domain"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
)
const (
@ -40,8 +38,7 @@ func (l *Login) handlePasswordCheck(w http.ResponseWriter, r *http.Request) {
l.renderError(w, r, authReq, err)
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.VerifyPassword(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Password, userAgentID, domain.BrowserInfoFromRequest(r))
err = l.authRepo.VerifyPassword(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Password, authReq.AgentID, authReq.InstanceID, domain.BrowserInfoFromRequest(r))
if err != nil {
l.renderPassword(w, r, authReq, err)
return

View File

@ -5,8 +5,6 @@ import (
"net/http"
"github.com/caos/zitadel/internal/domain"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
)
const (
@ -27,8 +25,7 @@ func (l *Login) renderPasswordlessVerification(w http.ResponseWriter, r *http.Re
var errID, errMessage, credentialData string
var webAuthNLogin *domain.WebAuthNLogin
if err == nil {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
webAuthNLogin, err = l.authRepo.BeginPasswordlessLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID)
webAuthNLogin, err = l.authRepo.BeginPasswordlessLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, authReq.AgentID, authReq.InstanceID)
}
if err != nil {
errID, errMessage = l.getErrorMessage(r, err)
@ -65,8 +62,7 @@ func (l *Login) handlePasswordlessVerification(w http.ResponseWriter, r *http.Re
l.renderPasswordlessVerification(w, r, authReq, formData.PasswordLogin, err)
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.VerifyPasswordless(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, credData, domain.BrowserInfoFromRequest(r))
err = l.authRepo.VerifyPasswordless(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, authReq.AgentID, authReq.InstanceID, credData, domain.BrowserInfoFromRequest(r))
if err != nil {
l.renderPasswordlessVerification(w, r, authReq, formData.PasswordLogin, err)
return

View File

@ -5,6 +5,7 @@ import (
"golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain"
caos_errs "github.com/caos/zitadel/internal/errors"
@ -61,7 +62,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
l.renderRegister(w, r, authRequest, data, err)
return
}
iam, err := l.query.IAMByID(r.Context(), domain.IAMID)
iam, err := l.query.IAM(r.Context())
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return
@ -94,7 +95,8 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, user.AggregateID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, user.AggregateID, userAgentID, instanceID)
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return
@ -125,7 +127,7 @@ func (l *Login) renderRegister(w http.ResponseWriter, r *http.Request, authReque
}
if resourceOwner == "" {
iam, err := l.query.IAMByID(r.Context(), domain.IAMID)
iam, err := l.query.IAM(r.Context())
if err != nil {
l.renderRegister(w, r, authRequest, formData, err)
return

View File

@ -224,8 +224,7 @@ func (l *Login) renderNextStep(w http.ResponseWriter, r *http.Request, authReq *
l.renderInternalError(w, r, nil, caos_errs.ThrowInvalidArgument(nil, "LOGIN-Df3f2", "Errors.AuthRequest.NotFound"))
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
authReq, err := l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID)
authReq, err := l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
if err != nil {
l.renderInternalError(w, r, authReq, err)
return

View File

@ -5,6 +5,7 @@ import (
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
)
@ -38,7 +39,8 @@ func (l *Login) handleSelectUser(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID)
instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID, instanceID)
if err != nil {
l.renderError(w, r, authSession, err)
return

View File

@ -8,30 +8,30 @@ import (
type AuthRequestRepository interface {
CreateAuthRequest(ctx context.Context, request *domain.AuthRequest) (*domain.AuthRequest, error)
AuthRequestByID(ctx context.Context, id, userAgentID string) (*domain.AuthRequest, error)
AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID string) (*domain.AuthRequest, error)
AuthRequestByCode(ctx context.Context, code string) (*domain.AuthRequest, error)
SaveAuthCode(ctx context.Context, id, code, userAgentID string) error
DeleteAuthRequest(ctx context.Context, id string) error
AuthRequestByID(ctx context.Context, id, userAgentID, instanceID string) (*domain.AuthRequest, error)
AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID, instanceID string) (*domain.AuthRequest, error)
AuthRequestByCode(ctx context.Context, code, instanceID string) (*domain.AuthRequest, error)
SaveAuthCode(ctx context.Context, id, code, userAgentID, instanceID string) error
DeleteAuthRequest(ctx context.Context, id, instanceID string) error
CheckLoginName(ctx context.Context, id, loginName, userAgentID string) error
CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo) error
SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error
SelectUser(ctx context.Context, id, userID, userAgentID string) error
SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string) error
VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error
CheckLoginName(ctx context.Context, id, loginName, userAgentID, instanceID string) error
CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, user *domain.ExternalUser, info *domain.BrowserInfo) error
SetExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, user *domain.ExternalUser) error
SelectUser(ctx context.Context, id, userID, userAgentID, instanceID string) error
SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID, instanceID string) error
VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID, instanceID string, info *domain.BrowserInfo) error
VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID string, info *domain.BrowserInfo) error
BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (*domain.WebAuthNLogin, error)
VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) error
VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID, instanceID string, info *domain.BrowserInfo) error
BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (*domain.WebAuthNLogin, error)
VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) error
BeginPasswordlessSetup(ctx context.Context, userID, resourceOwner string, preferredPlatformType domain.AuthenticatorAttachment) (login *domain.WebAuthNToken, err error)
VerifyPasswordlessSetup(ctx context.Context, userID, resourceOwner, userAgentID, tokenName string, credentialData []byte) (err error)
BeginPasswordlessInitCodeSetup(ctx context.Context, userID, resourceOwner, codeID, verificationCode string, preferredPlatformType domain.AuthenticatorAttachment) (login *domain.WebAuthNToken, err error)
VerifyPasswordlessInitCodeSetup(ctx context.Context, userID, resourceOwner, userAgentID, tokenName, codeID, verificationCode string, credentialData []byte) (err error)
BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (*domain.WebAuthNLogin, error)
VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) error
BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (*domain.WebAuthNLogin, error)
VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) error
LinkExternalUsers(ctx context.Context, authReqID, userAgentID string, info *domain.BrowserInfo) error
AutoRegisterExternalUser(ctx context.Context, user *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner string, metadatas []*domain.Metadata, info *domain.BrowserInfo) error
ResetLinkingUsers(ctx context.Context, authReqID, userAgentID string) error
LinkExternalUsers(ctx context.Context, authReqID, userAgentID, instanceID string, info *domain.BrowserInfo) error
AutoRegisterExternalUser(ctx context.Context, user *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner, instanceID string, metadatas []*domain.Metadata, info *domain.BrowserInfo) error
ResetLinkingUsers(ctx context.Context, authReqID, userAgentID, instanceID string) error
}

View File

@ -156,22 +156,22 @@ func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *dom
return request, nil
}
func (repo *AuthRequestRepo) AuthRequestByID(ctx context.Context, id, userAgentID string) (_ *domain.AuthRequest, err error) {
func (repo *AuthRequestRepo) AuthRequestByID(ctx context.Context, id, userAgentID, instanceID string) (_ *domain.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return repo.getAuthRequestNextSteps(ctx, id, userAgentID, false)
return repo.getAuthRequestNextSteps(ctx, id, userAgentID, instanceID, false)
}
func (repo *AuthRequestRepo) AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID string) (_ *domain.AuthRequest, err error) {
func (repo *AuthRequestRepo) AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID, instanceID string) (_ *domain.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return repo.getAuthRequestNextSteps(ctx, id, userAgentID, true)
return repo.getAuthRequestNextSteps(ctx, id, userAgentID, instanceID, true)
}
func (repo *AuthRequestRepo) SaveAuthCode(ctx context.Context, id, code, userAgentID string) (err error) {
func (repo *AuthRequestRepo) SaveAuthCode(ctx context.Context, id, code, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, id, userAgentID)
request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil {
return err
}
@ -179,10 +179,10 @@ func (repo *AuthRequestRepo) SaveAuthCode(ctx context.Context, id, code, userAge
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) AuthRequestByCode(ctx context.Context, code string) (_ *domain.AuthRequest, err error) {
func (repo *AuthRequestRepo) AuthRequestByCode(ctx context.Context, code, instanceID string) (_ *domain.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.AuthRequests.GetAuthRequestByCode(ctx, code)
request, err := repo.AuthRequests.GetAuthRequestByCode(ctx, code, instanceID)
if err != nil {
return nil, err
}
@ -198,16 +198,16 @@ func (repo *AuthRequestRepo) AuthRequestByCode(ctx context.Context, code string)
return request, nil
}
func (repo *AuthRequestRepo) DeleteAuthRequest(ctx context.Context, id string) (err error) {
func (repo *AuthRequestRepo) DeleteAuthRequest(ctx context.Context, id, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return repo.AuthRequests.DeleteAuthRequest(ctx, id)
return repo.AuthRequests.DeleteAuthRequest(ctx, id, instanceID)
}
func (repo *AuthRequestRepo) CheckLoginName(ctx context.Context, id, loginName, userAgentID string) (err error) {
func (repo *AuthRequestRepo) CheckLoginName(ctx context.Context, id, loginName, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, id, userAgentID)
request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil {
return err
}
@ -218,10 +218,10 @@ func (repo *AuthRequestRepo) CheckLoginName(ctx context.Context, id, loginName,
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string) (err error) {
func (repo *AuthRequestRepo) SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil {
return err
}
@ -232,10 +232,10 @@ func (repo *AuthRequestRepo) SelectExternalIDP(ctx context.Context, authReqID, i
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, externalUser *domain.ExternalUser, info *domain.BrowserInfo) (err error) {
func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, externalUser *domain.ExternalUser, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil {
return err
}
@ -257,10 +257,10 @@ func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReq
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, externalUser *domain.ExternalUser) (err error) {
func (repo *AuthRequestRepo) SetExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, externalUser *domain.ExternalUser) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil {
return err
}
@ -277,10 +277,10 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, id, userID, userAgentID string) (err error) {
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, id, userID, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, id, userAgentID)
request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil {
return err
}
@ -299,10 +299,10 @@ func (repo *AuthRequestRepo) SelectUser(ctx context.Context, id, userID, userAge
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) (err error) {
func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID, instanceID string, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, id, userAgentID, userID)
request, err := repo.getAuthRequestEnsureUser(ctx, id, userAgentID, userID, instanceID)
if err != nil {
return err
}
@ -328,31 +328,31 @@ func lockoutPolicyToDomain(policy *query.LockoutPolicy) *domain.LockoutPolicy {
}
}
func (repo *AuthRequestRepo) VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID string, info *domain.BrowserInfo) (err error) {
func (repo *AuthRequestRepo) VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID, instanceID string, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID)
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil {
return err
}
return repo.Command.HumanCheckMFAOTP(ctx, userID, code, resourceOwner, request.WithCurrentInfo(info))
}
func (repo *AuthRequestRepo) BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (login *domain.WebAuthNLogin, err error) {
func (repo *AuthRequestRepo) BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (login *domain.WebAuthNLogin, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID)
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil {
return nil, err
}
return repo.Command.HumanBeginU2FLogin(ctx, userID, resourceOwner, request, true)
}
func (repo *AuthRequestRepo) VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) (err error) {
func (repo *AuthRequestRepo) VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID)
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil {
return err
}
@ -393,30 +393,30 @@ func (repo *AuthRequestRepo) VerifyPasswordlessInitCodeSetup(ctx context.Context
return err
}
func (repo *AuthRequestRepo) BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (login *domain.WebAuthNLogin, err error) {
func (repo *AuthRequestRepo) BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (login *domain.WebAuthNLogin, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID)
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil {
return nil, err
}
return repo.Command.HumanBeginPasswordlessLogin(ctx, userID, resourceOwner, request, true)
}
func (repo *AuthRequestRepo) VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) (err error) {
func (repo *AuthRequestRepo) VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID)
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil {
return err
}
return repo.Command.HumanFinishPasswordlessLogin(ctx, userID, resourceOwner, credentialData, request, true)
}
func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, userAgentID string, info *domain.BrowserInfo) (err error) {
func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, userAgentID, instanceID string, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil {
return err
}
@ -432,8 +432,8 @@ func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, u
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) ResetLinkingUsers(ctx context.Context, authReqID, userAgentID string) error {
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
func (repo *AuthRequestRepo) ResetLinkingUsers(ctx context.Context, authReqID, userAgentID, instanceID string) error {
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil {
return err
}
@ -442,10 +442,10 @@ func (repo *AuthRequestRepo) ResetLinkingUsers(ctx context.Context, authReqID, u
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, registerUser *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner string, metadatas []*domain.Metadata, info *domain.BrowserInfo) (err error) {
func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, registerUser *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner, instanceID string, metadatas []*domain.Metadata, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil {
return err
}
@ -478,8 +478,8 @@ func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, regis
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) getAuthRequestNextSteps(ctx context.Context, id, userAgentID string, checkLoggedIn bool) (*domain.AuthRequest, error) {
request, err := repo.getAuthRequest(ctx, id, userAgentID)
func (repo *AuthRequestRepo) getAuthRequestNextSteps(ctx context.Context, id, userAgentID, instanceID string, checkLoggedIn bool) (*domain.AuthRequest, error) {
request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil {
return nil, err
}
@ -491,8 +491,8 @@ func (repo *AuthRequestRepo) getAuthRequestNextSteps(ctx context.Context, id, us
return request, nil
}
func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authRequestID, userAgentID, userID string) (*domain.AuthRequest, error) {
request, err := repo.getAuthRequest(ctx, authRequestID, userAgentID)
func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authRequestID, userAgentID, userID, instanceID string) (*domain.AuthRequest, error) {
request, err := repo.getAuthRequest(ctx, authRequestID, userAgentID, instanceID)
if err != nil {
return nil, err
}
@ -506,8 +506,8 @@ func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authR
return request, nil
}
func (repo *AuthRequestRepo) getAuthRequest(ctx context.Context, id, userAgentID string) (*domain.AuthRequest, error) {
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
func (repo *AuthRequestRepo) getAuthRequest(ctx context.Context, id, userAgentID, instanceID string) (*domain.AuthRequest, error) {
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id, instanceID)
if err != nil {
return nil, err
}

View File

@ -3,7 +3,7 @@ package handler
import (
"github.com/caos/logging"
"github.com/caos/zitadel/internal/eventstore/v1"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler"
@ -76,6 +76,7 @@ func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) {
case model.ProjectAdded:
mapping.OrgID = event.ResourceOwner
mapping.ProjectID = event.AggregateID
mapping.InstanceID = event.InstanceID
case model.ProjectRemoved:
err := p.view.DeleteOrgProjectMappingsByProjectID(event.AggregateID)
if err == nil {
@ -87,6 +88,7 @@ func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) {
mapping.OrgID = projectGrant.GrantedOrgID
mapping.ProjectID = event.AggregateID
mapping.ProjectGrantID = projectGrant.GrantID
mapping.InstanceID = projectGrant.InstanceID
case model.ProjectGrantRemoved:
projectGrant := new(view_model.ProjectGrant)
projectGrant.SetData(event)

View File

@ -2,9 +2,10 @@ package handler
import (
"github.com/caos/logging"
req_model "github.com/caos/zitadel/internal/auth_request/model"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler"
@ -104,6 +105,7 @@ func (u *UserSession) Reduce(event *models.Event) (err error) {
UserAgentID: eventData.UserAgentID,
UserID: event.AggregateID,
State: int32(req_model.UserSessionStateActive),
InstanceID: event.InstanceID,
}
}
return u.updateSession(session, event)

View File

@ -26,38 +26,38 @@ func (c *AuthRequestCache) Health(ctx context.Context) error {
return c.client.PingContext(ctx)
}
func (c *AuthRequestCache) GetAuthRequestByID(_ context.Context, id string) (*domain.AuthRequest, error) {
return c.getAuthRequest("id", id)
func (c *AuthRequestCache) GetAuthRequestByID(_ context.Context, id, instanceID string) (*domain.AuthRequest, error) {
return c.getAuthRequest("id", id, instanceID)
}
func (c *AuthRequestCache) GetAuthRequestByCode(_ context.Context, code string) (*domain.AuthRequest, error) {
return c.getAuthRequest("code", code)
func (c *AuthRequestCache) GetAuthRequestByCode(_ context.Context, code, instanceID string) (*domain.AuthRequest, error) {
return c.getAuthRequest("code", code, instanceID)
}
func (c *AuthRequestCache) SaveAuthRequest(_ context.Context, request *domain.AuthRequest) error {
return c.saveAuthRequest(request, "INSERT INTO auth.auth_requests (id, request, creation_date, change_date, request_type) VALUES($1, $2, $3, $3, $4)", request.CreationDate, request.Request.Type())
return c.saveAuthRequest(request, "INSERT INTO auth.auth_requests (id, request, instance_id, creation_date, change_date, request_type) VALUES($1, $2, $3, $3, $4, $5)", request.CreationDate, request.InstanceID, request.Request.Type())
}
func (c *AuthRequestCache) UpdateAuthRequest(_ context.Context, request *domain.AuthRequest) error {
if request.ChangeDate.IsZero() {
request.ChangeDate = time.Now()
}
return c.saveAuthRequest(request, "UPDATE auth.auth_requests SET request = $2, change_date = $3, code = $4 WHERE id = $1", request.ChangeDate, request.Code)
return c.saveAuthRequest(request, "UPDATE auth.auth_requests SET request = $2, instance_id = $3 change_date = $4, code = $5 WHERE id = $1", request.ChangeDate, request.InstanceID, request.Code)
}
func (c *AuthRequestCache) DeleteAuthRequest(_ context.Context, id string) error {
_, err := c.client.Exec("DELETE FROM auth.auth_requests WHERE id = $1", id)
func (c *AuthRequestCache) DeleteAuthRequest(_ context.Context, id, instanceID string) error {
_, err := c.client.Exec("DELETE FROM auth.auth_requests WHERE instance = $1 and id = $2", instanceID, id)
if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-dsHw3", "unable to delete auth request")
}
return nil
}
func (c *AuthRequestCache) getAuthRequest(key, value string) (*domain.AuthRequest, error) {
func (c *AuthRequestCache) getAuthRequest(key, value, instanceID string) (*domain.AuthRequest, error) {
var b []byte
var requestType domain.AuthRequestType
query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE %s = $1", key)
err := c.client.QueryRow(query, value).Scan(&b, &requestType)
query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE instance = $1 and %s = $2", key)
err := c.client.QueryRow(query, instanceID, value).Scan(&b, &requestType)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, caos_errs.ThrowNotFound(err, "CACHE-d24aD", "Errors.AuthRequest.NotFound")
@ -74,7 +74,7 @@ func (c *AuthRequestCache) getAuthRequest(key, value string) (*domain.AuthReques
return request, nil
}
func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query string, date time.Time, param interface{}) error {
func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query string, date time.Time, instanceID string, param interface{}) error {
b, err := json.Marshal(request)
if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-os0GH", "Errors.Internal")
@ -83,7 +83,7 @@ func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query st
if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-su3GK", "Errors.Internal")
}
_, err = stmt.Exec(request.ID, b, date, param)
_, err = stmt.Exec(request.ID, b, date, instanceID, param)
if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-sj8iS", "Errors.Internal")
}

View File

@ -6,79 +6,80 @@ package mock
import (
context "context"
reflect "reflect"
domain "github.com/caos/zitadel/internal/domain"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockAuthRequestCache is a mock of AuthRequestCache interface
// MockAuthRequestCache is a mock of AuthRequestCache interface.
type MockAuthRequestCache struct {
ctrl *gomock.Controller
recorder *MockAuthRequestCacheMockRecorder
}
// MockAuthRequestCacheMockRecorder is the mock recorder for MockAuthRequestCache
// MockAuthRequestCacheMockRecorder is the mock recorder for MockAuthRequestCache.
type MockAuthRequestCacheMockRecorder struct {
mock *MockAuthRequestCache
}
// NewMockAuthRequestCache creates a new mock instance
// NewMockAuthRequestCache creates a new mock instance.
func NewMockAuthRequestCache(ctrl *gomock.Controller) *MockAuthRequestCache {
mock := &MockAuthRequestCache{ctrl: ctrl}
mock.recorder = &MockAuthRequestCacheMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder {
return m.recorder
}
// DeleteAuthRequest mocks base method
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
// DeleteAuthRequest mocks base method.
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1, arg2)
}
// GetAuthRequestByCode mocks base method
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
// GetAuthRequestByCode mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1)
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1, arg2)
ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
// GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1, arg2)
}
// GetAuthRequestByID mocks base method
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
// GetAuthRequestByID mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1)
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1, arg2)
ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthRequestByID indicates an expected call of GetAuthRequestByID
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
// GetAuthRequestByID indicates an expected call of GetAuthRequestByID.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1, arg2)
}
// Health mocks base method
// Health mocks base method.
func (m *MockAuthRequestCache) Health(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0)
@ -86,13 +87,13 @@ func (m *MockAuthRequestCache) Health(arg0 context.Context) error {
return ret0
}
// Health indicates an expected call of Health
// Health indicates an expected call of Health.
func (mr *MockAuthRequestCacheMockRecorder) Health(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockAuthRequestCache)(nil).Health), arg0)
}
// SaveAuthRequest mocks base method
// SaveAuthRequest mocks base method.
func (m *MockAuthRequestCache) SaveAuthRequest(arg0 context.Context, arg1 *domain.AuthRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveAuthRequest", arg0, arg1)
@ -100,13 +101,13 @@ func (m *MockAuthRequestCache) SaveAuthRequest(arg0 context.Context, arg1 *domai
return ret0
}
// SaveAuthRequest indicates an expected call of SaveAuthRequest
// SaveAuthRequest indicates an expected call of SaveAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) SaveAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).SaveAuthRequest), arg0, arg1)
}
// UpdateAuthRequest mocks base method
// UpdateAuthRequest mocks base method.
func (m *MockAuthRequestCache) UpdateAuthRequest(arg0 context.Context, arg1 *domain.AuthRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAuthRequest", arg0, arg1)
@ -114,7 +115,7 @@ func (m *MockAuthRequestCache) UpdateAuthRequest(arg0 context.Context, arg1 *dom
return ret0
}
// UpdateAuthRequest indicates an expected call of UpdateAuthRequest
// UpdateAuthRequest indicates an expected call of UpdateAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) UpdateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).UpdateAuthRequest), arg0, arg1)

View File

@ -2,15 +2,16 @@ package repository
import (
"context"
"github.com/caos/zitadel/internal/domain"
)
type AuthRequestCache interface {
Health(ctx context.Context) error
GetAuthRequestByID(ctx context.Context, id string) (*domain.AuthRequest, error)
GetAuthRequestByCode(ctx context.Context, code string) (*domain.AuthRequest, error)
GetAuthRequestByID(ctx context.Context, id, instanceID string) (*domain.AuthRequest, error)
GetAuthRequestByCode(ctx context.Context, code, instanceID string) (*domain.AuthRequest, error)
SaveAuthRequest(ctx context.Context, request *domain.AuthRequest) error
UpdateAuthRequest(ctx context.Context, request *domain.AuthRequest) error
DeleteAuthRequest(ctx context.Context, id string) error
DeleteAuthRequest(ctx context.Context, id, instanceID string) error
}

View File

@ -236,7 +236,7 @@ func (repo *TokenVerifierRepo) VerifierClientID(ctx context.Context, appName str
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
iam, err := repo.Query.IAMByID(ctx, domain.IAMID)
iam, err := repo.Query.IAM(ctx)
if err != nil {
return "", "", err
}

View File

@ -28,6 +28,7 @@ func (repo *UserMembershipRepo) SearchMyMemberships(ctx context.Context) ([]*aut
func (repo *UserMembershipRepo) searchUserMemberships(ctx context.Context) ([]*user_view_model.UserMembershipView, error) {
ctxData := authz.GetCtxData(ctx)
instance := authz.GetInstance(ctx)
orgMemberships, orgCount, err := repo.View.SearchUserMemberships(&user_model.UserMembershipSearchRequest{
Queries: []*user_model.UserMembershipSearchQuery{
{
@ -40,6 +41,11 @@ func (repo *UserMembershipRepo) searchUserMemberships(ctx context.Context) ([]*u
Method: domain.SearchMethodEquals,
Value: ctxData.OrgID,
},
{
Key: user_model.UserMembershipSearchKeyInstanceID,
Method: domain.SearchMethodEquals,
Value: instance.ID,
},
},
})
if err != nil {
@ -57,6 +63,11 @@ func (repo *UserMembershipRepo) searchUserMemberships(ctx context.Context) ([]*u
Method: domain.SearchMethodEquals,
Value: domain.IAMID,
},
{
Key: user_model.UserMembershipSearchKeyInstanceID,
Method: domain.SearchMethodEquals,
Value: instance.ID,
},
},
})
if err != nil {

View File

@ -30,8 +30,6 @@ func (h *handler) Eventstore() v1.Eventstore {
func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es v1.Eventstore, systemDefaults sd.SystemDefaults) []query.Handler {
return []query.Handler{
newUserGrant(
handler{view, bulkLimit, configs.cycleDuration("UserGrants"), errorCount, es}),
newUserMembership(
handler{view, bulkLimit, configs.cycleDuration("UserMemberships"), errorCount, es}),
}

View File

@ -1,313 +0,0 @@
package handler
import (
"context"
"strings"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
es_sdk "github.com/caos/zitadel/internal/eventstore/v1/sdk"
iam_model "github.com/caos/zitadel/internal/iam/model"
iam_view "github.com/caos/zitadel/internal/iam/repository/view"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
caos_errs "github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler"
iam_es_model "github.com/caos/zitadel/internal/iam/repository/eventsourcing/model"
org_es_model "github.com/caos/zitadel/internal/org/repository/eventsourcing/model"
proj_es_model "github.com/caos/zitadel/internal/project/repository/eventsourcing/model"
view_model "github.com/caos/zitadel/internal/usergrant/repository/view/model"
)
const (
userGrantTable = "authz.user_grants"
)
type UserGrant struct {
handler
iamProjectID string
subscription *v1.Subscription
}
func newUserGrant(
handler handler,
) *UserGrant {
h := &UserGrant{
handler: handler,
}
h.subscribe()
return h
}
func (k *UserGrant) subscribe() {
k.subscription = k.es.Subscribe(k.AggregateTypes()...)
go func() {
for event := range k.subscription.Events {
query.ReduceEvent(k, event)
}
}()
}
func (u *UserGrant) ViewModel() string {
return userGrantTable
}
func (u *UserGrant) Subscription() *v1.Subscription {
return u.subscription
}
func (_ *UserGrant) AggregateTypes() []es_models.AggregateType {
return []es_models.AggregateType{iam_es_model.IAMAggregate, org_es_model.OrgAggregate, proj_es_model.ProjectAggregate}
}
func (u *UserGrant) CurrentSequence() (uint64, error) {
sequence, err := u.view.GetLatestUserGrantSequence()
if err != nil {
return 0, err
}
return sequence.CurrentSequence, nil
}
func (u *UserGrant) EventQuery() (*es_models.SearchQuery, error) {
if u.iamProjectID == "" {
err := u.setIamProjectID()
if err != nil {
return nil, err
}
}
sequence, err := u.view.GetLatestUserGrantSequence()
if err != nil {
return nil, err
}
return es_models.NewSearchQuery().
AggregateTypeFilter(iam_es_model.IAMAggregate, org_es_model.OrgAggregate, proj_es_model.ProjectAggregate).
LatestSequenceFilter(sequence.CurrentSequence), nil
}
func (u *UserGrant) Reduce(event *es_models.Event) (err error) {
switch event.AggregateType {
case proj_es_model.ProjectAggregate:
err = u.processProject(event)
case iam_es_model.IAMAggregate:
err = u.processIAMMember(event, "IAM", false)
case org_es_model.OrgAggregate:
return u.processOrg(event)
}
return err
}
func (u *UserGrant) processProject(event *es_models.Event) (err error) {
switch event.Type {
case proj_es_model.ProjectMemberAdded, proj_es_model.ProjectMemberChanged,
proj_es_model.ProjectMemberRemoved, proj_es_model.ProjectMemberCascadeRemoved:
member := new(proj_es_model.ProjectMember)
err := member.SetData(event)
if err != nil {
return err
}
return u.processMember(event, "PROJECT", event.AggregateID, member.UserID, member.Roles)
case proj_es_model.ProjectGrantMemberAdded, proj_es_model.ProjectGrantMemberChanged,
proj_es_model.ProjectGrantMemberRemoved,
proj_es_model.ProjectGrantMemberCascadeRemoved:
member := new(proj_es_model.ProjectGrantMember)
err := member.SetData(event)
if err != nil {
return err
}
return u.processMember(event, "PROJECT_GRANT", member.GrantID, member.UserID, member.Roles)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func (u *UserGrant) processOrg(event *es_models.Event) (err error) {
switch event.Type {
case org_es_model.OrgMemberAdded, org_es_model.OrgMemberChanged,
org_es_model.OrgMemberRemoved, org_es_model.OrgMemberCascadeRemoved:
member := new(org_es_model.OrgMember)
err := member.SetData(event)
if err != nil {
return err
}
return u.processMember(event, "ORG", "", member.UserID, member.Roles)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func (u *UserGrant) processIAMMember(event *es_models.Event, rolePrefix string, suffix bool) error {
member := new(iam_es_model.IAMMember)
switch event.Type {
case iam_es_model.IAMMemberAdded, iam_es_model.IAMMemberChanged:
member.SetData(event)
grant, err := u.view.UserGrantByIDs(domain.IAMID, u.iamProjectID, member.UserID)
if err != nil && !errors.IsNotFound(err) {
return err
}
if errors.IsNotFound(err) {
grant = &view_model.UserGrantView{
ID: u.iamProjectID + member.UserID,
ResourceOwner: domain.IAMID,
OrgName: domain.IAMID,
ProjectID: u.iamProjectID,
UserID: member.UserID,
RoleKeys: member.Roles,
CreationDate: event.CreationDate,
}
if suffix {
grant.RoleKeys = suffixRoles(event.AggregateID, grant.RoleKeys)
}
} else {
newRoles := member.Roles
if grant.RoleKeys != nil {
grant.RoleKeys = mergeExistingRoles(rolePrefix, "", grant.RoleKeys, newRoles)
} else {
grant.RoleKeys = newRoles
}
}
grant.Sequence = event.Sequence
grant.ChangeDate = event.CreationDate
return u.view.PutUserGrant(grant, event)
case iam_es_model.IAMMemberRemoved,
iam_es_model.IAMMemberCascadeRemoved:
member.SetData(event)
grant, err := u.view.UserGrantByIDs(domain.IAMID, u.iamProjectID, member.UserID)
if err != nil {
return err
}
return u.view.DeleteUserGrant(grant.ID, event)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func (u *UserGrant) processMember(event *es_models.Event, rolePrefix, roleSuffix string, userID string, roleKeys []string) error {
switch event.Type {
case org_es_model.OrgMemberAdded, proj_es_model.ProjectMemberAdded, proj_es_model.ProjectGrantMemberAdded,
org_es_model.OrgMemberChanged, proj_es_model.ProjectMemberChanged, proj_es_model.ProjectGrantMemberChanged:
grant, err := u.view.UserGrantByIDs(event.ResourceOwner, u.iamProjectID, userID)
if err != nil && !errors.IsNotFound(err) {
return err
}
if roleSuffix != "" {
roleKeys = suffixRoles(roleSuffix, roleKeys)
}
if errors.IsNotFound(err) {
grant = &view_model.UserGrantView{
ID: u.iamProjectID + event.ResourceOwner + userID,
ResourceOwner: event.ResourceOwner,
ProjectID: u.iamProjectID,
UserID: userID,
RoleKeys: roleKeys,
CreationDate: event.CreationDate,
}
} else {
newRoles := roleKeys
if grant.RoleKeys != nil {
grant.RoleKeys = mergeExistingRoles(rolePrefix, roleSuffix, grant.RoleKeys, newRoles)
} else {
grant.RoleKeys = newRoles
}
}
grant.Sequence = event.Sequence
grant.ChangeDate = event.CreationDate
return u.view.PutUserGrant(grant, event)
case org_es_model.OrgMemberRemoved,
org_es_model.OrgMemberCascadeRemoved,
proj_es_model.ProjectMemberRemoved,
proj_es_model.ProjectMemberCascadeRemoved,
proj_es_model.ProjectGrantMemberRemoved,
proj_es_model.ProjectGrantMemberCascadeRemoved:
grant, err := u.view.UserGrantByIDs(event.ResourceOwner, u.iamProjectID, userID)
if err != nil && !errors.IsNotFound(err) {
return err
}
if errors.IsNotFound(err) {
return u.view.ProcessedUserGrantSequence(event)
}
if roleSuffix != "" {
roleKeys = suffixRoles(roleSuffix, roleKeys)
}
if grant.RoleKeys == nil {
return u.view.ProcessedUserGrantSequence(event)
}
grant.RoleKeys = mergeExistingRoles(rolePrefix, roleSuffix, grant.RoleKeys, nil)
return u.view.PutUserGrant(grant, event)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func suffixRoles(suffix string, roles []string) []string {
suffixedRoles := make([]string, len(roles))
for i := 0; i < len(roles); i++ {
suffixedRoles[i] = roles[i] + ":" + suffix
}
return suffixedRoles
}
func mergeExistingRoles(rolePrefix, suffix string, existingRoles, newRoles []string) []string {
mergedRoles := make([]string, 0)
for _, existingRole := range existingRoles {
if !strings.HasPrefix(existingRole, rolePrefix) {
mergedRoles = append(mergedRoles, existingRole)
continue
}
if suffix != "" && !strings.HasSuffix(existingRole, suffix) {
mergedRoles = append(mergedRoles, existingRole)
}
}
return append(mergedRoles, newRoles...)
}
func (u *UserGrant) setIamProjectID() error {
if u.iamProjectID != "" {
return nil
}
iam, err := u.getIAMByID(context.Background())
if err != nil {
return err
}
if iam.SetUpDone < domain.StepCount-1 {
return caos_errs.ThrowPreconditionFailed(nil, "HANDL-s5DTs", "Setup not done")
}
u.iamProjectID = iam.IAMProjectID
return nil
}
func (u *UserGrant) OnError(event *es_models.Event, err error) error {
logging.LogWithFields("SPOOL-VcVoJ", "id", event.AggregateID).WithError(err).Warn("something went wrong in user grant handler")
return spooler.HandleError(event, err, u.view.GetLatestUserGrantFailedEvent, u.view.ProcessedUserGrantFailedEvent, u.view.ProcessedUserGrantSequence, u.errorCountUntilSkip)
}
func (u *UserGrant) OnSuccess() error {
return spooler.HandleSuccess(u.view.UpdateUserGrantSpoolerRunTimestamp)
}
func (u *UserGrant) getIAMByID(ctx context.Context) (*iam_model.IAM, error) {
query, err := iam_view.IAMByIDQuery(domain.IAMID, 0)
if err != nil {
return nil, err
}
iam := &iam_es_model.IAM{
ObjectRoot: es_models.ObjectRoot{
AggregateID: domain.IAMID,
},
}
err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, iam.AppendEvents, query)
if err != nil && errors.IsNotFound(err) && iam.Sequence == 0 {
return nil, err
}
return iam_es_model.IAMToModel(iam), nil
}

View File

@ -1,70 +0,0 @@
package view
import (
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
grant_model "github.com/caos/zitadel/internal/usergrant/model"
"github.com/caos/zitadel/internal/usergrant/repository/view"
"github.com/caos/zitadel/internal/usergrant/repository/view/model"
"github.com/caos/zitadel/internal/view/repository"
)
const (
userGrantTable = "authz.user_grants"
)
func (v *View) UserGrantByID(grantID string) (*model.UserGrantView, error) {
return view.UserGrantByID(v.Db, userGrantTable, grantID)
}
func (v *View) UserGrantByIDs(resourceOwnerID, projectID, userID string) (*model.UserGrantView, error) {
return view.UserGrantByIDs(v.Db, userGrantTable, resourceOwnerID, projectID, userID)
}
func (v *View) UserGrantsByUserID(userID string) ([]*model.UserGrantView, error) {
return view.UserGrantsByUserID(v.Db, userGrantTable, userID)
}
func (v *View) UserGrantsByProjectID(projectID string) ([]*model.UserGrantView, error) {
return view.UserGrantsByProjectID(v.Db, userGrantTable, projectID)
}
func (v *View) SearchUserGrants(request *grant_model.UserGrantSearchRequest) ([]*model.UserGrantView, uint64, error) {
return view.SearchUserGrants(v.Db, userGrantTable, request)
}
func (v *View) PutUserGrant(grant *model.UserGrantView, event *models.Event) error {
err := view.PutUserGrant(v.Db, userGrantTable, grant)
if err != nil {
return err
}
return v.ProcessedUserGrantSequence(event)
}
func (v *View) DeleteUserGrant(grantID string, event *models.Event) error {
err := view.DeleteUserGrant(v.Db, userGrantTable, grantID)
if err != nil && !errors.IsNotFound(err) {
return err
}
return v.ProcessedUserGrantSequence(event)
}
func (v *View) GetLatestUserGrantSequence() (*repository.CurrentSequence, error) {
return v.latestSequence(userGrantTable)
}
func (v *View) ProcessedUserGrantSequence(event *models.Event) error {
return v.saveCurrentSequence(userGrantTable, event)
}
func (v *View) UpdateUserGrantSpoolerRunTimestamp() error {
return v.updateSpoolerRunSequence(userGrantTable)
}
func (v *View) GetLatestUserGrantFailedEvent(sequence uint64) (*repository.FailedEvent, error) {
return v.latestFailedEvent(userGrantTable, sequence)
}
func (v *View) ProcessedUserGrantFailedEvent(failedEvent *repository.FailedEvent) error {
return v.saveFailedEvent(failedEvent)
}

View File

@ -23,6 +23,7 @@ type AuthRequest struct {
UiLocales []string
LoginHint string
MaxAuthAge *time.Duration
InstanceID string
Request Request
levelOfAssurance LevelOfAssurance

View File

@ -21,7 +21,7 @@ func NewAggregate(
ID: id,
Type: typ,
ResourceOwner: authz.GetCtxData(ctx).OrgID,
Tenant: authz.GetCtxData(ctx).TenantID,
InstanceID: authz.GetInstance(ctx).ID,
Version: version,
}
@ -50,7 +50,7 @@ func AggregateFromWriteModel(
ID: wm.AggregateID,
Type: typ,
ResourceOwner: wm.ResourceOwner,
Tenant: wm.Tenant,
InstanceID: wm.InstanceID,
Version: version,
}
}
@ -63,8 +63,8 @@ type Aggregate struct {
Type AggregateType `json:"-"`
//ResourceOwner is the org this aggregates belongs to
ResourceOwner string `json:"-"`
//Tenant is the system this aggregate belongs to
Tenant string `json:"-"`
//InstanceID is the instance this aggregate belongs to
InstanceID string `json:"-"`
//Version is the semver this aggregate represents
Version Version `json:"-"`
}

View File

@ -79,7 +79,7 @@ func BaseEventFromRepo(event *repository.Event) *BaseEvent {
ID: event.AggregateID,
Type: AggregateType(event.AggregateType),
ResourceOwner: event.ResourceOwner.String,
Tenant: event.Tenant.String,
InstanceID: event.InstanceID.String,
Version: Version(event.Version),
},
EventType: EventType(event.Type),

View File

@ -41,7 +41,7 @@ func (es *Eventstore) Health(ctx context.Context) error {
//Push pushes the events in a single transaction
// an event needs at least an aggregate
func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error) {
events, constraints, err := commandsToRepository(authz.GetCtxData(ctx).TenantID, cmds)
events, constraints, err := commandsToRepository(authz.GetInstance(ctx).ID, cmds)
if err != nil {
return nil, err
}
@ -59,7 +59,7 @@ func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error
return eventReaders, nil
}
func commandsToRepository(tenantID string, cmds []Command) (events []*repository.Event, constraints []*repository.UniqueConstraint, err error) {
func commandsToRepository(instanceID string, cmds []Command) (events []*repository.Event, constraints []*repository.UniqueConstraint, err error) {
events = make([]*repository.Event, len(cmds))
for i, cmd := range cmds {
data, err := EventData(cmd)
@ -82,7 +82,7 @@ func commandsToRepository(tenantID string, cmds []Command) (events []*repository
AggregateID: cmd.Aggregate().ID,
AggregateType: repository.AggregateType(cmd.Aggregate().Type),
ResourceOwner: sql.NullString{String: cmd.Aggregate().ResourceOwner, Valid: cmd.Aggregate().ResourceOwner != ""},
Tenant: sql.NullString{String: tenantID, Valid: tenantID != ""},
InstanceID: sql.NullString{String: instanceID, Valid: instanceID != ""},
EditorService: cmd.EditorService(),
EditorUser: cmd.EditorUser(),
Type: repository.EventType(cmd.Type()),
@ -113,7 +113,7 @@ func uniqueConstraintsToRepository(constraints []*EventUniqueConstraint) (unique
//Filter filters the stored events based on the searchQuery
// and maps the events to the defined event structs
func (es *Eventstore) Filter(ctx context.Context, queryFactory *SearchQueryBuilder) ([]Event, error) {
query, err := queryFactory.build(authz.GetCtxData(ctx).TenantID)
query, err := queryFactory.build(authz.GetInstance(ctx).ID)
if err != nil {
return nil, err
}
@ -170,7 +170,7 @@ func (es *Eventstore) FilterToReducer(ctx context.Context, searchQuery *SearchQu
//LatestSequence filters the latest sequence for the given search query
func (es *Eventstore) LatestSequence(ctx context.Context, queryFactory *SearchQueryBuilder) (uint64, error) {
query, err := queryFactory.build(authz.GetCtxData(ctx).TenantID)
query, err := queryFactory.build(authz.GetInstance(ctx).ID)
if err != nil {
return 0, err
}

View File

@ -29,7 +29,7 @@ func newTestEvent(id, description string, data func() interface{}, checkPrevious
data: data,
shouldCheckPrevious: checkPrevious,
BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"),
service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate(authz.NewMockContext("zitadel", "caos", "adlerhurst"), id, "test.aggregate", "v1"),
"test.event",
),
@ -344,8 +344,8 @@ func Test_eventData(t *testing.T) {
func TestEventstore_aggregatesToEvents(t *testing.T) {
type args struct {
tenantID string
events []Command
instanceID string
events []Command
}
type res struct {
wantErr bool
@ -359,7 +359,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
{
name: "one aggregate one event",
args: args{
tenantID: "tenant",
instanceID: "instanceID",
events: []Command{
newTestEvent(
"1",
@ -380,7 +380,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "tenant", Valid: true},
InstanceID: sql.NullString{String: "instanceID", Valid: true},
Type: "test.event",
Version: "v1",
},
@ -390,7 +390,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
{
name: "one aggregate multiple events",
args: args{
tenantID: "tenant",
instanceID: "instanceID",
events: []Command{
newTestEvent(
"1",
@ -418,7 +418,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "tenant", Valid: true},
InstanceID: sql.NullString{String: "instanceID", Valid: true},
Type: "test.event",
Version: "v1",
},
@ -429,7 +429,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "tenant", Valid: true},
InstanceID: sql.NullString{String: "instanceID", Valid: true},
Type: "test.event",
Version: "v1",
},
@ -439,7 +439,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
{
name: "invalid data",
args: args{
tenantID: "tenant",
instanceID: "instanceID",
events: []Command{
newTestEvent(
"1",
@ -460,7 +460,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{
&testEvent{
BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"),
service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"",
@ -485,7 +485,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{
&testEvent{
BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"),
service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"id",
@ -510,7 +510,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{
&testEvent{
BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"),
service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"id",
@ -535,7 +535,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{
&testEvent{
BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"),
service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"id",
@ -560,7 +560,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{
&testEvent{
BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "", "editorUser"), "editorService"),
service.WithService(authz.NewMockContext("instanceID", "", "editorUser"), "editorService"),
NewAggregate(
authz.NewMockContext("zitadel", "", "adlerhurst"),
"id",
@ -585,7 +585,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "", Valid: false},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -630,7 +630,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -641,7 +641,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -654,7 +654,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -665,7 +665,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
events, _, err := commandsToRepository(tt.args.tenantID, tt.args.events)
events, _, err := commandsToRepository(tt.args.instanceID, tt.args.events)
if (err != nil) != tt.res.wantErr {
t.Errorf("Eventstore.aggregatesToEvents() error = %v, wantErr %v", err, tt.res.wantErr)
return
@ -772,7 +772,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -816,7 +816,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -827,7 +827,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -882,7 +882,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -893,7 +893,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},
@ -906,7 +906,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService",
EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"},
InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event",
Version: "v1",
},

View File

@ -8,15 +8,16 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/caos/zitadel/internal/eventstore"
)
type mockExpectation func(sqlmock.Sqlmock)
func expectFailureCount(tableName string, projectionName string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`).
WithArgs(projectionName, failedSeq).
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\ AND instance_id = \$3\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`).
WithArgs(projectionName, failedSeq, instanceID).
WillReturnRows(
sqlmock.NewRows([]string{"failure_count"}).
AddRow(failureCount),
@ -24,10 +25,10 @@ func expectFailureCount(tableName string, projectionName string, failedSeq, fail
}
}
func expectUpdateFailureCount(tableName string, projectionName string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error\) VALUES \(\$1, \$2, \$3, \$4\)`).
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\)`).
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID).WillReturnResult(sqlmock.NewResult(1, 1))
}
}

View File

@ -4,15 +4,16 @@ import (
"database/sql"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/handler"
)
const (
setFailureCountStmtFormat = "UPSERT INTO %s" +
" (projection_name, failed_sequence, failure_count, error)" +
" VALUES ($1, $2, $3, $4)"
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2)" +
" (projection_name, failed_sequence, failure_count, error, instance_id)" +
" VALUES ($1, $2, $3, $4, $5)"
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
" SELECT IF(" +
"EXISTS(SELECT failure_count FROM failures)," +
" (SELECT failure_count FROM failures)," +
@ -21,31 +22,31 @@ const (
)
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {
failureCount, err := h.failureCount(tx, stmt.Sequence)
failureCount, err := h.failureCount(tx, stmt.Sequence, stmt.InstanceID)
if err != nil {
logging.WithFields("projection", h.ProjectionName, "seq", stmt.Sequence).WithError(err).Warn("unable to get failure count")
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).WithError(err).Warn("unable to get failure count")
return false
}
failureCount += 1
err = h.setFailureCount(tx, stmt.Sequence, failureCount, execErr)
logging.WithFields("projection", h.ProjectionName, "seq", stmt.Sequence).OnError(err).Warn("unable to update failure count")
err = h.setFailureCount(tx, stmt.Sequence, failureCount, execErr, stmt.InstanceID)
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).OnError(err).Warn("unable to update failure count")
return failureCount >= h.maxFailureCount
}
func (h *StatementHandler) failureCount(tx *sql.Tx, seq uint64) (count uint, err error) {
row := tx.QueryRow(h.failureCountStmt, h.ProjectionName, seq)
func (h *StatementHandler) failureCount(tx *sql.Tx, seq uint64, instanceID string) (count uint, err error) {
row := tx.QueryRow(h.failureCountStmt, h.ProjectionName, seq, instanceID)
if err = row.Err(); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
}
if err = row.Scan(&count); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scann count")
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
}
return count, nil
}
func (h *StatementHandler) setFailureCount(tx *sql.Tx, seq uint64, count uint, err error) error {
_, dbErr := tx.Exec(h.setFailureCountStmt, h.ProjectionName, seq, count, err.Error())
func (h *StatementHandler) setFailureCount(tx *sql.Tx, seq uint64, count uint, err error, instanceID string) error {
_, dbErr := tx.Exec(h.setFailureCountStmt, h.ProjectionName, seq, count, err.Error(), instanceID)
if dbErr != nil {
return errors.ThrowInternal(dbErr, "CRDB-4Ht4x", "set failure count failed")
}

View File

@ -26,7 +26,8 @@ type StatementHandlerConfig struct {
MaxFailureCount uint
BulkLimit uint64
Reducers []handler.AggregateReducer
Reducers []handler.AggregateReducer
InitCheck *handler.Check
}
type StatementHandler struct {
@ -75,6 +76,9 @@ func NewStatementHandler(
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
}
err := h.Init(ctx, config.InitCheck)
logging.OnError(err).Fatal("unable to initialize projections")
go h.ProjectionHandler.Process(
ctx,
h.reduce,
@ -214,7 +218,7 @@ func (h *StatementHandler) executeStmts(
continue
}
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "seq", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
break
}
err := h.executeStmt(tx, stmt)

View File

@ -28,6 +28,7 @@ type testEvent struct {
sequence uint64
previousSequence uint64
aggregateType eventstore.AggregateType
instanceID string
}
func (e *testEvent) Sequence() uint64 {
@ -36,7 +37,8 @@ func (e *testEvent) Sequence() uint64 {
func (e *testEvent) Aggregate() eventstore.Aggregate {
return eventstore.Aggregate{
Type: e.aggregateType,
Type: e.aggregateType,
InstanceID: e.instanceID,
}
}
@ -786,6 +788,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg",
sequence: 5,
previousSequence: 0,
instanceID: "instanceID",
},
[]handler.Column{
{
@ -798,6 +801,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg",
sequence: 6,
previousSequence: 5,
instanceID: "instanceID",
},
[]handler.Column{
{
@ -810,6 +814,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg",
sequence: 7,
previousSequence: 6,
instanceID: "instanceID",
},
[]handler.Column{
{
@ -830,8 +835,8 @@ func TestStatementHandler_executeStmts(t *testing.T) {
expectSavePoint(),
expectCreateErr("my_projection", []string{"col"}, []string{"$1"}, sql.ErrConnDone),
expectSavePointRollback(),
expectFailureCount("failed_events", "my_projection", 6, 3),
expectUpdateFailureCount("failed_events", "my_projection", 6, 4),
expectFailureCount("failed_events", "my_projection", "instanceID", 6, 3),
expectUpdateFailureCount("failed_events", "my_projection", "instanceID", 6, 4),
},
idx: 0,
},
@ -850,6 +855,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg",
sequence: 5,
previousSequence: 0,
instanceID: "instanceID",
},
[]handler.Column{
{
@ -862,6 +868,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg",
sequence: 6,
previousSequence: 5,
instanceID: "instanceID",
},
[]handler.Column{
{
@ -874,6 +881,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg",
sequence: 7,
previousSequence: 6,
instanceID: "instanceID",
},
[]handler.Column{
{
@ -894,8 +902,8 @@ func TestStatementHandler_executeStmts(t *testing.T) {
expectSavePoint(),
expectCreateErr("my_projection", []string{"col2"}, []string{"$1"}, sql.ErrConnDone),
expectSavePointRollback(),
expectFailureCount("failed_events", "my_projection", 6, 4),
expectUpdateFailureCount("failed_events", "my_projection", 6, 5),
expectFailureCount("failed_events", "my_projection", "instanceID", 6, 4),
expectUpdateFailureCount("failed_events", "my_projection", "instanceID", 6, 5),
expectSavePoint(),
expectCreate("my_projection", []string{"col3"}, []string{"$1"}),
expectSavePointRelease(),

View File

@ -0,0 +1,320 @@
package crdb
import (
"context"
"errors"
"fmt"
"strings"
"github.com/caos/logging"
"github.com/lib/pq"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/handler"
)
type Table struct {
columns []*Column
primaryKey PrimaryKey
indices []*Index
}
func NewTable(columns []*Column, key PrimaryKey, indices ...*Index) *Table {
return &Table{
columns: columns,
primaryKey: key,
indices: indices,
}
}
type SuffixedTable struct {
Table
suffix string
}
func NewSuffixedTable(columns []*Column, key PrimaryKey, suffix string, indices ...*Index) *SuffixedTable {
return &SuffixedTable{
Table: Table{
columns: columns,
primaryKey: key,
indices: indices,
},
suffix: suffix,
}
}
type Column struct {
Name string
Type ColumnType
nullable bool
defaultValue interface{}
deleteCascade string
}
type ColumnOption func(*Column)
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *Column {
column := &Column{
Name: name,
Type: columnType,
nullable: false,
defaultValue: nil,
}
for _, opt := range opts {
opt(column)
}
return column
}
func Nullable() ColumnOption {
return func(c *Column) {
c.nullable = true
}
}
func Default(value interface{}) ColumnOption {
return func(c *Column) {
c.defaultValue = value
}
}
func DeleteCascade(column string) ColumnOption {
return func(c *Column) {
c.deleteCascade = column
}
}
type PrimaryKey []string
func NewPrimaryKey(columnNames ...string) PrimaryKey {
return columnNames
}
type ColumnType int32
const (
ColumnTypeText ColumnType = iota
ColumnTypeTextArray
ColumnTypeJSONB
ColumnTypeBytes
ColumnTypeTimestamp
ColumnTypeEnum
ColumnTypeEnumArray
ColumnTypeInt64
ColumnTypeBool
)
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
i := &Index{
Name: name,
Columns: columns,
bucketCount: 0,
}
for _, opt := range opts {
opt(i)
}
return i
}
type Index struct {
Name string
Columns []string
bucketCount uint16
}
type indexOpts func(*Index)
func Hash(bucketsCount uint16) indexOpts {
return func(i *Index) {
i.bucketCount = bucketsCount
}
}
//Init implements handler.Init
func (h *StatementHandler) Init(ctx context.Context, checks ...*handler.Check) error {
for _, check := range checks {
if check == nil || check.IsNoop() {
return nil
}
tx, err := h.client.BeginTx(ctx, nil)
if err != nil {
return caos_errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
}
for i, execute := range check.Executes {
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("executing check")
next, err := execute(h.client, h.ProjectionName)
if err != nil {
tx.Rollback()
return err
}
if !next {
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("skipping next check")
break
}
}
if err := tx.Commit(); err != nil {
return err
}
}
return nil
}
func NewTableCheck(table *Table, opts ...execOption) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
return createTableStatement(table, config.tableName, "")
}
executes := make([]func(handler.Executer, string) (bool, error), len(table.indices)+1)
executes[0] = execNextIfExists(config, create, opts, true)
for i, index := range table.indices {
executes[i+1] = execNextIfExists(config, createIndexStatement(index), opts, true)
}
return &handler.Check{
Executes: executes,
}
}
func NewMultiTableCheck(primaryTable *Table, secondaryTables ...*SuffixedTable) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
stmt := createTableStatement(primaryTable, config.tableName, "")
for _, table := range secondaryTables {
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
}
return stmt
}
return &handler.Check{
Executes: []func(handler.Executer, string) (bool, error){
execNextIfExists(config, create, nil, true),
},
}
}
func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
var stmt string
for _, table := range secondaryTables {
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
}
stmt += createViewStatement(config.tableName, selectStmt)
return stmt
}
return &handler.Check{
Executes: []func(handler.Executer, string) (bool, error){
execNextIfExists(config, create, nil, false),
},
}
}
func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) {
return func(handler handler.Executer, name string) (bool, error) {
err := exec(config, q, opts)(handler, name)
if isErrAlreadyExists(err) {
return executeNext, nil
}
return false, err
}
}
func isErrAlreadyExists(err error) bool {
caosErr := &caos_errs.CaosError{}
if !errors.As(err, &caosErr) {
return false
}
sqlErr, ok := caosErr.GetParent().(*pq.Error)
if !ok {
return false
}
return sqlErr.Routine == "NewRelationAlreadyExistsError"
}
func createTableStatement(table *Table, tableName string, suffix string) string {
stmt := fmt.Sprintf("CREATE TABLE %s (%s, PRIMARY KEY (%s)",
tableName+suffix,
createColumnsStatement(table.columns, tableName),
strings.Join(table.primaryKey, ", "),
)
for _, index := range table.indices {
stmt += fmt.Sprintf(", INDEX %s (%s)", index.Name, strings.Join(index.Columns, ","))
}
return stmt + ");"
}
func createViewStatement(viewName string, selectStmt string) string {
return fmt.Sprintf("CREATE VIEW %s AS %s",
viewName,
selectStmt,
)
}
func createIndexStatement(index *Index) func(config execConfig) string {
return func(config execConfig) string {
stmt := fmt.Sprintf("CREATE INDEX %s ON %s (%s)",
index.Name,
config.tableName,
strings.Join(index.Columns, ","),
)
if index.bucketCount == 0 {
return stmt + ";"
}
return fmt.Sprintf("SET experimental_enable_hash_sharded_indexes=on; %s USING HASH WITH BUCKET_COUNT = %d;",
stmt, index.bucketCount)
}
}
func createColumnsStatement(cols []*Column, tableName string) string {
columns := make([]string, len(cols))
for i, col := range cols {
column := col.Name + " " + columnType(col.Type)
if !col.nullable {
column += " NOT NULL"
}
if col.defaultValue != nil {
column += " DEFAULT " + defaultValue(col.defaultValue)
}
if col.deleteCascade != "" {
column += fmt.Sprintf(" REFERENCES %s (%s) ON DELETE CASCADE", tableName, col.deleteCascade)
}
columns[i] = column
}
return strings.Join(columns, ",")
}
func defaultValue(value interface{}) string {
switch v := value.(type) {
case string:
return "'" + v + "'"
default:
return fmt.Sprintf("%v", v)
}
}
func columnType(columnType ColumnType) string {
switch columnType {
case ColumnTypeText:
return "TEXT"
case ColumnTypeTextArray:
return "TEXT[]"
case ColumnTypeTimestamp:
return "TIMESTAMPTZ"
case ColumnTypeEnum:
return "SMALLINT"
case ColumnTypeEnumArray:
return "SMALLINT[]"
case ColumnTypeInt64:
return "BIGINT"
case ColumnTypeBool:
return "BOOLEAN"
case ColumnTypeJSONB:
return "JSONB"
case ColumnTypeBytes:
return "BYTES"
default:
panic("") //TODO: remove?
return ""
}
}

View File

@ -37,7 +37,7 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
workerName, err := os.Hostname()
if err != nil || workerName == "" {
workerName, err = id.SonyFlakeGenerator.Next()
logging.Log("CRDB-bdO56").OnError(err).Panic("unable to generate lockID")
logging.OnError(err).Panic("unable to generate lockID")
}
return &locker{
client: client,

View File

@ -6,7 +6,7 @@ import (
"github.com/lib/pq"
"github.com/caos/zitadel/internal/errors"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/handler"
)
@ -46,6 +46,7 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
}
@ -71,6 +72,7 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
}
@ -104,6 +106,7 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
}
@ -129,6 +132,7 @@ func NewDeleteStatement(event eventstore.Event, conditions []handler.Condition,
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
}
@ -138,6 +142,7 @@ func NewNoOpStatement(event eventstore.Event) *handler.Statement {
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
}
}
@ -153,6 +158,7 @@ func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Ex
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: multiExec(execs),
}
}
@ -278,6 +284,7 @@ func NewCopyStatement(event eventstore.Event, cols []handler.Column, conds []han
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
}
@ -327,7 +334,7 @@ func exec(config execConfig, q query, opts []execOption) Exec {
}
if _, err := ex.Exec(q(config), config.args...); err != nil {
return errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
return caos_errs.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
}
return nil

View File

@ -608,6 +608,7 @@ func TestNewNoOpStatement(t *testing.T) {
aggregateType: "agg",
sequence: 5,
previousSequence: 3,
instanceID: "instanceID",
},
},
want: &handler.Statement{
@ -615,6 +616,7 @@ func TestNewNoOpStatement(t *testing.T) {
Execute: nil,
Sequence: 5,
PreviousSequence: 3,
InstanceID: "instanceID",
},
},
}

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/eventstore"
)
@ -270,7 +271,7 @@ func (h *ProjectionHandler) fetchBulkStmts(
for _, event := range events {
if err = h.processEvent(ctx, event, reduce); err != nil {
logging.WithFields("projection", h.ProjectionName, "seq", event.Sequence()).WithError(err).Warn("unable to process event in bulk")
logging.WithFields("projection", h.ProjectionName, "sequence", event.Sequence(), "instanceID", event.Aggregate().InstanceID).WithError(err).Warn("unable to process event in bulk")
return false, err
}
}

View File

@ -0,0 +1,14 @@
package handler
import "context"
//Init initializes the projection with the given check
type Init func(context.Context, *Check) error
type Check struct {
Executes []func(ex Executer, projectionName string) (bool, error)
}
func (c *Check) IsNoop() bool {
return len(c.Executes) == 0
}

View File

@ -27,6 +27,7 @@ type Statement struct {
AggregateType eventstore.AggregateType
Sequence uint64
PreviousSequence uint64
InstanceID string
Execute func(ex Executer, projectionName string) error
}

View File

@ -12,7 +12,7 @@ type ReadModel struct {
ChangeDate time.Time `json:"-"`
Events []Event `json:"-"`
ResourceOwner string `json:"-"`
Tenant string `json:"-"`
InstanceID string `json:"-"`
}
//AppendEvents adds all the events to the read model.
@ -35,8 +35,8 @@ func (rm *ReadModel) Reduce() error {
if rm.ResourceOwner == "" {
rm.ResourceOwner = rm.Events[0].Aggregate().ResourceOwner
}
if rm.Tenant == "" {
rm.Tenant = rm.Events[0].Aggregate().Tenant
if rm.InstanceID == "" {
rm.InstanceID = rm.Events[0].Aggregate().InstanceID
}
if rm.CreationDate.IsZero() {

View File

@ -56,9 +56,9 @@ type Event struct {
// an aggregate can only be managed by one organisation
// use the ID of the org
ResourceOwner sql.NullString
//Tenant is the system where this event belongs to
// use the ID of the tenant
Tenant sql.NullString
//InstanceID is the instance where this event belongs to
// use the ID of the instance
InstanceID sql.NullString
}
//EventType is the description of the change

View File

@ -66,8 +66,8 @@ const (
FieldSequence
//FieldResourceOwner represents the resource owner field
FieldResourceOwner
//FieldTenant represents the tenant field
FieldTenant
//FieldInstanceID represents the instance id field
FieldInstanceID
//FieldEditorService represents the editor service field
FieldEditorService
//FieldEditorUser represents the editor user field

View File

@ -30,7 +30,7 @@ const (
" SELECT MAX(event_sequence) seq, 1 join_me" +
" FROM eventstore.events" +
" WHERE aggregate_type = $2" +
" AND (CASE WHEN $9::STRING IS NULL THEN tenant is null else tenant = $9::STRING END)" +
" AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
") AS agg_type " +
// combined with
"LEFT JOIN " +
@ -39,7 +39,7 @@ const (
" SELECT event_sequence seq, resource_owner ro, 1 join_me" +
" FROM eventstore.events" +
" WHERE aggregate_type = $2 AND aggregate_id = $3" +
" AND (CASE WHEN $9::STRING IS NULL THEN tenant is null else tenant = $9::STRING END)" +
" AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
" ORDER BY event_sequence DESC" +
" LIMIT 1" +
") AS agg USING(join_me)" +
@ -54,7 +54,7 @@ const (
" editor_user," +
" editor_service," +
" resource_owner," +
" tenant," +
" instance_id," +
" event_sequence," +
" previous_aggregate_sequence," +
" previous_aggregate_type_sequence" +
@ -70,12 +70,12 @@ const (
" $6::VARCHAR AS editor_user," +
" $7::VARCHAR AS editor_service," +
" IFNULL((resource_owner), $8::VARCHAR) AS resource_owner," +
" $9::VARCHAR AS tenant," +
" $9::VARCHAR AS instance_id," +
" NEXTVAL(CONCAT('eventstore.', IFNULL($9, 'system'), '_seq'))," +
" aggregate_sequence AS previous_aggregate_sequence," +
" aggregate_type_sequence AS previous_aggregate_type_sequence " +
"FROM previous_data " +
"RETURNING id, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, creation_date, resource_owner, tenant"
"RETURNING id, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, creation_date, resource_owner, instance_id"
uniqueInsert = `INSERT INTO eventstore.unique_constraints
(
@ -120,8 +120,8 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
event.EditorUser,
event.EditorService,
event.ResourceOwner,
event.Tenant,
).Scan(&event.ID, &event.Sequence, &previousAggregateSequence, &previousAggregateTypeSequence, &event.CreationDate, &event.ResourceOwner, &event.Tenant)
event.InstanceID,
).Scan(&event.ID, &event.Sequence, &previousAggregateSequence, &previousAggregateTypeSequence, &event.CreationDate, &event.ResourceOwner, &event.InstanceID)
event.PreviousAggregateSequence = uint64(previousAggregateSequence)
event.PreviousAggregateTypeSequence = uint64(previousAggregateTypeSequence)
@ -132,7 +132,7 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
"aggregateId", event.AggregateID,
"aggregateType", event.AggregateType,
"eventType", event.Type,
"tenant", event.Tenant,
"instanceID", event.InstanceID,
).WithError(err).Info("query failed")
return caos_errs.ThrowInternal(err, "SQL-SBP37", "unable to create event")
}
@ -229,7 +229,7 @@ func (db *CRDB) eventQuery() string {
", editor_service" +
", editor_user" +
", resource_owner" +
", tenant" +
", instance_id" +
", aggregate_type" +
", aggregate_id" +
", aggregate_version" +
@ -250,8 +250,8 @@ func (db *CRDB) columnName(col repository.Field) string {
return "event_sequence"
case repository.FieldResourceOwner:
return "resource_owner"
case repository.FieldTenant:
return "tenant"
case repository.FieldInstanceID:
return "instance_id"
case repository.FieldEditorService:
return "editor_service"
case repository.FieldEditorUser:

View File

@ -109,7 +109,7 @@ func eventsScanner(scanner scan, dest interface{}) (err error) {
&event.EditorService,
&event.EditorUser,
&event.ResourceOwner,
&event.Tenant,
&event.InstanceID,
&event.AggregateType,
&event.AggregateID,
&event.Version,

View File

@ -130,7 +130,7 @@ func Test_prepareColumns(t *testing.T) {
dest: &[]*repository.Event{},
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
expected: []*repository.Event{
{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
},
@ -146,7 +146,7 @@ func Test_prepareColumns(t *testing.T) {
dest: []*repository.Event{},
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsErrorInvalidArgument,
},
},
@ -158,7 +158,7 @@ func Test_prepareColumns(t *testing.T) {
dbErr: sql.ErrConnDone,
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsInternal,
},
},
@ -592,7 +592,7 @@ func Test_query_events_mocked(t *testing.T) {
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")},
),
},
@ -621,7 +621,7 @@ func Test_query_events_mocked(t *testing.T) {
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence LIMIT \$2`,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)},
),
},
@ -650,7 +650,7 @@ func Test_query_events_mocked(t *testing.T) {
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC LIMIT \$2`,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)},
),
},
@ -679,7 +679,7 @@ func Test_query_events_mocked(t *testing.T) {
},
fields: fields{
mock: newMockClient(t).expectQueryErr(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")},
sql.ErrConnDone),
},
@ -708,7 +708,7 @@ func Test_query_events_mocked(t *testing.T) {
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")},
&repository.Event{Sequence: 100}),
},
@ -776,7 +776,7 @@ func Test_query_events_mocked(t *testing.T) {
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) OR \( aggregate_type = \$2 AND aggregate_id = \$3 \) ORDER BY event_sequence DESC LIMIT \$4`,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) OR \( aggregate_type = \$2 AND aggregate_id = \$3 \) ORDER BY event_sequence DESC LIMIT \$4`,
[]driver.Value{repository.AggregateType("user"), repository.AggregateType("org"), "asdf42", uint64(5)},
),
},

View File

@ -12,7 +12,7 @@ type SearchQueryBuilder struct {
limit uint64
desc bool
resourceOwner string
tenant string
instanceID string
queries []*SearchQuery
}
@ -68,9 +68,9 @@ func (factory *SearchQueryBuilder) ResourceOwner(resourceOwner string) *SearchQu
return factory
}
//Tenant defines the tenant (system) of the events
func (factory *SearchQueryBuilder) Tenant(tenant string) *SearchQueryBuilder {
factory.tenant = tenant
//InstanceID defines the instanceID (system) of the events
func (factory *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder {
factory.instanceID = instanceID
return factory
}
@ -145,13 +145,13 @@ func (query *SearchQuery) Builder() *SearchQueryBuilder {
return query.builder
}
func (builder *SearchQueryBuilder) build(tenantID string) (*repository.SearchQuery, error) {
func (builder *SearchQueryBuilder) build(instanceID string) (*repository.SearchQuery, error) {
if builder == nil ||
len(builder.queries) < 1 ||
builder.columns.Validate() != nil {
return nil, errors.ThrowPreconditionFailed(nil, "MODEL-4m9gs", "builder invalid")
}
builder.tenant = tenantID
builder.instanceID = instanceID
filters := make([][]*repository.Filter, len(builder.queries))
for i, query := range builder.queries {
@ -163,7 +163,7 @@ func (builder *SearchQueryBuilder) build(tenantID string) (*repository.SearchQue
query.eventSequenceGreaterFilter,
query.eventSequenceLessFilter,
query.builder.resourceOwnerFilter,
query.builder.tenantFilter,
query.builder.instanceIDFilter,
} {
if filter := f(); filter != nil {
if err := filter.Validate(); err != nil {
@ -247,11 +247,11 @@ func (builder *SearchQueryBuilder) resourceOwnerFilter() *repository.Filter {
return repository.NewFilter(repository.FieldResourceOwner, builder.resourceOwner, repository.OperationEquals)
}
func (builder *SearchQueryBuilder) tenantFilter() *repository.Filter {
if builder.tenant == "" {
func (builder *SearchQueryBuilder) instanceIDFilter() *repository.Filter {
if builder.instanceID == "" {
return nil
}
return repository.NewFilter(repository.FieldTenant, builder.tenant, repository.OperationEquals)
return repository.NewFilter(repository.FieldInstanceID, builder.instanceID, repository.OperationEquals)
}
func (query *SearchQuery) eventDataFilter() *repository.Filter {

View File

@ -224,9 +224,9 @@ func TestSearchQuerybuilderSetters(t *testing.T) {
func TestSearchQuerybuilderBuild(t *testing.T) {
type args struct {
columns Columns
setters []func(*SearchQueryBuilder) *SearchQueryBuilder
tenant string
columns Columns
setters []func(*SearchQueryBuilder) *SearchQueryBuilder
instanceID string
}
type res struct {
isErr func(err error) bool
@ -622,7 +622,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
},
},
{
name: "filter aggregate type and tenant",
name: "filter aggregate type and instanceID",
args: args{
columns: ColumnsEvent,
setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{
@ -630,7 +630,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
testSetAggregateTypes("user"),
),
},
tenant: "tenant",
instanceID: "instanceID",
},
res: res{
isErr: nil,
@ -641,7 +641,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, repository.AggregateType("user"), repository.OperationEquals),
repository.NewFilter(repository.FieldTenant, "tenant", repository.OperationEquals),
repository.NewFilter(repository.FieldInstanceID, "instanceID", repository.OperationEquals),
},
},
},
@ -668,7 +668,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
for _, f := range tt.args.setters {
builder = f(builder)
}
query, err := builder.build(tt.args.tenant)
query, err := builder.build(tt.args.instanceID)
if tt.res.isErr != nil && !tt.res.isErr(err) {
t.Errorf("wrong error(%T): %v", err, err)
return

View File

@ -122,7 +122,7 @@ func mapEventToV1Event(event Event) *models.Event {
AggregateType: models.AggregateType(event.Aggregate().Type),
AggregateID: event.Aggregate().ID,
ResourceOwner: event.Aggregate().ResourceOwner,
Tenant: event.Aggregate().Tenant,
InstanceID: event.Aggregate().InstanceID,
EditorService: event.EditorService(),
EditorUser: event.EditorUser(),
Data: event.DataAsBytes(),

View File

@ -7,15 +7,16 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
const (
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1`
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1`
)
var (
eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "tenant", "aggregate_type", "aggregate_id", "aggregate_version"}
eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "instance_id", "aggregate_type", "aggregate_id", "aggregate_version"}
expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence LIMIT \$2`).String()
expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence DESC`).String()
expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
@ -23,7 +24,7 @@ var (
expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence`).String()
expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` +
`\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, tenant, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` +
`\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` +
`SELECT \$1, \$2, \$3, \$4, COALESCE\(\$5, now\(\)\), \$6, \$7, \$8, \$9, \$10, \$11 ` +
`WHERE EXISTS \(` +
`SELECT 1 FROM eventstore\.events WHERE aggregate_type = \$12 AND aggregate_id = \$13 HAVING MAX\(event_sequence\) = \$14 OR \(\$14::BIGINT IS NULL AND COUNT\(\*\) = 0\)\) ` +
@ -99,7 +100,7 @@ func (db *dbMock) expectRollback(err error) *dbMock {
func (db *dbMock) expectInsertEvent(e *models.Event, returnedSequence uint64) *dbMock {
db.mock.ExpectQuery(expectedInsertStatement).
WithArgs(
e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.Tenant, Sequence(e.PreviousSequence),
e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.InstanceID, Sequence(e.PreviousSequence),
e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
).
WillReturnRows(
@ -113,7 +114,7 @@ func (db *dbMock) expectInsertEvent(e *models.Event, returnedSequence uint64) *d
func (db *dbMock) expectInsertEventError(e *models.Event) *dbMock {
db.mock.ExpectQuery(expectedInsertStatement).
WithArgs(
e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.Tenant, Sequence(e.PreviousSequence),
e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.InstanceID, Sequence(e.PreviousSequence),
e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
).
WillReturnError(sql.ErrTxDone)
@ -124,7 +125,7 @@ func (db *dbMock) expectInsertEventError(e *models.Event) *dbMock {
func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, eventCount int) *dbMock {
rows := sqlmock.NewRows(eventColumns)
for i := 0; i < eventCount; i++ {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0")
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
}
db.mock.ExpectQuery(expectedFilterEventsLimitFormat).
WithArgs(aggregateType, limit).
@ -135,7 +136,7 @@ func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, ev
func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *dbMock {
rows := sqlmock.NewRows(eventColumns)
for i := eventCount; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0")
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
}
db.mock.ExpectQuery(expectedFilterEventsDescFormat).
WillReturnRows(rows)
@ -145,7 +146,7 @@ func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *
func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
rows := sqlmock.NewRows(eventColumns)
for i := limit; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0")
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
}
db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit).
WithArgs(aggregateType, aggregateID, limit).
@ -156,7 +157,7 @@ func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID
func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
rows := sqlmock.NewRows(eventColumns)
for i := limit; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0")
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
}
db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit).
WithArgs(aggregateType, aggregateID, limit).

View File

@ -8,9 +8,10 @@ import (
"strings"
"github.com/caos/logging"
"github.com/lib/pq"
z_errors "github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/lib/pq"
)
const (
@ -23,7 +24,7 @@ const (
", editor_service" +
", editor_user" +
", resource_owner" +
", tenant" +
", instance_id" +
", aggregate_type" +
", aggregate_id" +
", aggregate_version" +
@ -117,7 +118,7 @@ func prepareColumns(columns es_models.Columns) (string, func(s scan, dest interf
&event.EditorService,
&event.EditorUser,
&event.ResourceOwner,
&event.Tenant,
&event.InstanceID,
&event.AggregateType,
&event.AggregateID,
&event.AggregateVersion,
@ -177,8 +178,8 @@ func getField(field es_models.Field) string {
return "event_sequence"
case es_models.Field_ResourceOwner:
return "resource_owner"
case es_models.Field_Tenant:
return "tenant"
case es_models.Field_InstanceID:
return "instance_id"
case es_models.Field_EditorService:
return "editor_service"
case es_models.Field_EditorUser:

View File

@ -6,9 +6,10 @@ import (
"testing"
"time"
"github.com/lib/pq"
"github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/lib/pq"
)
func Test_numberPlaceholder(t *testing.T) {
@ -80,7 +81,7 @@ func Test_getField(t *testing.T) {
es_models.Field_AggregateID: "aggregate_id",
es_models.Field_LatestSequence: "event_sequence",
es_models.Field_ResourceOwner: "resource_owner",
es_models.Field_Tenant: "tenant",
es_models.Field_InstanceID: "instance_id",
es_models.Field_EditorService: "editor_service",
es_models.Field_EditorUser: "editor_user",
es_models.Field_EventType: "event_type",
@ -235,7 +236,7 @@ func Test_prepareColumns(t *testing.T) {
dest: new(es_models.Event),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbRow: []interface{}{time.Time{}, es_models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", "", es_models.AggregateType("user"), "hodor", es_models.Version("")},
expected: es_models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
},
@ -247,7 +248,7 @@ func Test_prepareColumns(t *testing.T) {
dest: new(uint64),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsErrorInvalidArgument,
},
},
@ -259,7 +260,7 @@ func Test_prepareColumns(t *testing.T) {
dbErr: sql.ErrConnDone,
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsInternal,
},
},
@ -430,7 +431,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user")},
},
@ -441,7 +442,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,
@ -453,7 +454,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,

View File

@ -23,7 +23,7 @@ type Aggregate struct {
editorService string
editorUser string
resourceOwner string
tenant string
instanceID string
Events []*Event
Precondition *precondition
}
@ -56,7 +56,7 @@ func (a *Aggregate) AppendEvent(typ EventType, payload interface{}) (*Aggregate,
EditorService: a.editorService,
EditorUser: a.editorUser,
ResourceOwner: a.resourceOwner,
Tenant: a.tenant,
InstanceID: a.instanceID,
}
a.Events = append(a.Events, e)

View File

@ -18,9 +18,10 @@ type option func(*Aggregate)
func (c *AggregateCreator) NewAggregate(ctx context.Context, id string, typ AggregateType, version Version, previousSequence uint64, opts ...option) (*Aggregate, error) {
ctxData := authz.GetCtxData(ctx)
instance := authz.GetInstance(ctx)
editorUser := ctxData.UserID
resourceOwner := ctxData.OrgID
tenant := ctxData.TenantID
instanceID := instance.ID
aggregate := &Aggregate{
ID: id,
@ -31,7 +32,7 @@ func (c *AggregateCreator) NewAggregate(ctx context.Context, id string, typ Aggr
editorService: c.serviceName,
editorUser: editorUser,
resourceOwner: resourceOwner,
tenant: tenant,
instanceID: instanceID,
}
for _, opt := range opts {

View File

@ -28,7 +28,7 @@ type Event struct {
EditorService string
EditorUser string
ResourceOwner string
Tenant string
InstanceID string
}
func eventData(i interface{}) ([]byte, error) {

View File

@ -11,5 +11,5 @@ const (
Field_EditorUser
Field_EventType
Field_CreationDate
Field_Tenant
Field_InstanceID
)

View File

@ -8,7 +8,7 @@ type ObjectRoot struct {
AggregateID string `json:"-"`
Sequence uint64 `json:"-"`
ResourceOwner string `json:"-"`
Tenant string `json:"-"`
InstanceID string `json:"-"`
CreationDate time.Time `json:"-"`
ChangeDate time.Time `json:"-"`
}
@ -22,8 +22,8 @@ func (o *ObjectRoot) AppendEvent(event *Event) {
if o.ResourceOwner == "" {
o.ResourceOwner = event.ResourceOwner
}
if o.Tenant == "" {
o.Tenant = event.Tenant
if o.InstanceID == "" {
o.InstanceID = event.InstanceID
}
o.ChangeDate = event.CreationDate

View File

@ -4,6 +4,7 @@ import (
"time"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/errors"
)
@ -17,7 +18,7 @@ type SearchQueryFactory struct {
sequenceTo uint64
eventTypes []EventType
resourceOwner string
tenant string
instanceID string
creationDate time.Time
}
@ -63,8 +64,8 @@ func FactoryFromSearchQuery(query *SearchQuery) *SearchQueryFactory {
}
case Field_ResourceOwner:
factory = factory.ResourceOwner(filter.value.(string))
case Field_Tenant:
factory = factory.Tenant(filter.value.(string))
case Field_InstanceID:
factory = factory.InstanceID(filter.value.(string))
case Field_EventType:
factory = factory.EventTypes(filter.value.([]EventType)...)
case Field_EditorService, Field_EditorUser:
@ -123,8 +124,8 @@ func (factory *SearchQueryFactory) ResourceOwner(resourceOwner string) *SearchQu
return factory
}
func (factory *SearchQueryFactory) Tenant(tenant string) *SearchQueryFactory {
factory.tenant = tenant
func (factory *SearchQueryFactory) InstanceID(instanceID string) *SearchQueryFactory {
factory.instanceID = instanceID
return factory
}
@ -159,7 +160,7 @@ func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
factory.sequenceToFilter,
factory.eventTypeFilter,
factory.resourceOwnerFilter,
factory.tenantFilter,
factory.instanceIDFilter,
factory.creationDateNewerFilter,
} {
if filter := f(); filter != nil {
@ -231,11 +232,11 @@ func (factory *SearchQueryFactory) resourceOwnerFilter() *Filter {
return NewFilter(Field_ResourceOwner, factory.resourceOwner, Operation_Equals)
}
func (factory *SearchQueryFactory) tenantFilter() *Filter {
if factory.tenant == "" {
func (factory *SearchQueryFactory) instanceIDFilter() *Filter {
if factory.instanceID == "" {
return nil
}
return NewFilter(Field_Tenant, factory.tenant, Operation_Equals)
return NewFilter(Field_InstanceID, factory.instanceID, Operation_Equals)
}
func (factory *SearchQueryFactory) creationDateNewerFilter() *Filter {

View File

@ -69,8 +69,8 @@ func (q *SearchQuery) ResourceOwnerFilter(resourceOwner string) *SearchQuery {
return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals))
}
func (q *SearchQuery) TenantFilter(tenant string) *SearchQuery {
return q.setFilter(NewFilter(Field_Tenant, tenant, Operation_Equals))
func (q *SearchQuery) InstanceIDFilter(instanceID string) *SearchQuery {
return q.setFilter(NewFilter(Field_InstanceID, instanceID, Operation_Equals))
}
func (q *SearchQuery) CreationDateNewerFilter(time time.Time) *SearchQuery {

View File

@ -54,7 +54,7 @@ func ReduceEvent(handler Handler, event *models.Event) {
unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery)
if err != nil {
logging.LogWithFields("HANDL-L6YH1", "seq", event.Sequence).Warn("filter failed")
logging.WithFields("HANDL-L6YH1", "sequence", event.Sequence).Warn("filter failed")
return
}
@ -74,12 +74,12 @@ func ReduceEvent(handler Handler, event *models.Event) {
}
err = handler.Reduce(unprocessedEvent)
logging.LogWithFields("HANDL-V42TI", "seq", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
logging.WithFields("HANDL-V42TI", "sequence", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
}
if len(unprocessedEvents) == eventLimit {
logging.LogWithFields("QUERY-BSqe9", "seq", event.Sequence).Warn("didnt process event")
logging.WithFields("QUERY-BSqe9", "sequence", event.Sequence).Warn("didnt process event")
return
}
err = handler.Reduce(event)
logging.LogWithFields("HANDL-wQDL2", "seq", event.Sequence).OnError(err).Warn("reduce failed")
logging.WithFields("HANDL-wQDL2", "sequence", event.Sequence).OnError(err).Warn("reduce failed")
}

View File

@ -10,7 +10,7 @@ type WriteModel struct {
ProcessedSequence uint64 `json:"-"`
Events []Event `json:"-"`
ResourceOwner string `json:"-"`
Tenant string `json:"-"`
InstanceID string `json:"-"`
ChangeDate time.Time `json:"-"`
}
@ -33,8 +33,8 @@ func (wm *WriteModel) Reduce() error {
if wm.ResourceOwner == "" {
wm.ResourceOwner = wm.Events[0].Aggregate().ResourceOwner
}
if wm.Tenant == "" {
wm.Tenant = wm.Events[0].Aggregate().Tenant
if wm.InstanceID == "" {
wm.InstanceID = wm.Events[0].Aggregate().InstanceID
}
wm.ProcessedSequence = wm.Events[len(wm.Events)-1].Sequence()

View File

@ -51,6 +51,7 @@ const (
IDPConfigSearchKeyAggregateID
IDPConfigSearchKeyIdpConfigID
IDPConfigSearchKeyIdpProviderType
IDPConfigSearchKeyInstanceID
)
type IDPConfigSearchQuery struct {

View File

@ -48,6 +48,7 @@ const (
LabelPolicySearchKeyUnspecified LabelPolicySearchKey = iota
LabelPolicySearchKeyAggregateID
LabelPolicySearchKeyState
LabelPolicySearchKeyInstanceID
)
type LabelPolicySearchQuery struct {

View File

@ -24,6 +24,7 @@ const (
IDPConfigKeyAggregateID = "aggregate_id"
IDPConfigKeyName = "name"
IDPConfigKeyProviderType = "idp_provider_type"
IDPConfigKeyInstanceID = "instance_id"
)
type IDPConfigView struct {
@ -50,7 +51,8 @@ type IDPConfigView struct {
JWTKeysEndpoint string `json:"keysEndpoint" gorm:"jwt_keys_endpoint"`
JWTHeaderName string `json:"headerName" gorm:"jwt_header_name"`
Sequence uint64 `json:"-" gorm:"column:sequence"`
Sequence uint64 `json:"-" gorm:"column:sequence"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
}
func IDPConfigViewToModel(idp *IDPConfigView) *model.IDPConfigView {
@ -120,6 +122,7 @@ func (i *IDPConfigView) AppendEvent(providerType model.IDPProviderType, event *m
func (r *IDPConfigView) setRootData(event *models.Event) {
r.AggregateID = event.AggregateID
r.InstanceID = event.InstanceID
}
func (r *IDPConfigView) SetData(event *models.Event) error {

View File

@ -59,6 +59,8 @@ func (key IDPConfigSearchKey) ToColumnName() string {
return IDPConfigKeyName
case iam_model.IDPConfigSearchKeyIdpProviderType:
return IDPConfigKeyProviderType
case iam_model.IDPConfigSearchKeyInstanceID:
return IDPConfigKeyInstanceID
default:
return ""
}

View File

@ -2,12 +2,14 @@ package model
import (
"encoding/json"
org_es_model "github.com/caos/zitadel/internal/org/repository/eventsourcing/model"
"time"
org_es_model "github.com/caos/zitadel/internal/org/repository/eventsourcing/model"
es_model "github.com/caos/zitadel/internal/iam/repository/eventsourcing/model"
"github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/iam/model"
@ -32,7 +34,8 @@ type IDPProviderView struct {
IDPProviderType int32 `json:"idpProviderType" gorm:"column:idp_provider_type"`
IDPState int32 `json:"-" gorm:"column:idp_state"`
Sequence uint64 `json:"-" gorm:"column:sequence"`
Sequence uint64 `json:"-" gorm:"column:sequence"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
}
func IDPProviderViewFromModel(provider *model.IDPProviderView) *IDPProviderView {
@ -87,6 +90,7 @@ func (i *IDPProviderView) AppendEvent(event *models.Event) (err error) {
func (r *IDPProviderView) setRootData(event *models.Event) {
r.AggregateID = event.AggregateID
r.InstanceID = event.InstanceID
}
func (r *IDPProviderView) SetData(event *models.Event) error {

View File

@ -19,6 +19,7 @@ import (
const (
LabelPolicyKeyAggregateID = "aggregate_id"
LabelPolicyKeyState = "label_policy_state"
LabelPolicyKeyInstanceID = "instance_id"
)
type LabelPolicyView struct {
@ -45,7 +46,8 @@ type LabelPolicyView struct {
DisableWatermark bool `json:"disableWatermark" gorm:"column:disable_watermark"`
Default bool `json:"-" gorm:"-"`
Sequence uint64 `json:"-" gorm:"column:sequence"`
Sequence uint64 `json:"-" gorm:"column:sequence"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
}
type AssetView struct {
@ -189,6 +191,7 @@ func (i *LabelPolicyView) AppendEvent(event *models.Event) (err error) {
func (r *LabelPolicyView) setRootData(event *models.Event) {
r.AggregateID = event.AggregateID
r.InstanceID = event.InstanceID
}
func (r *LabelPolicyView) SetData(event *models.Event) error {

View File

@ -55,6 +55,9 @@ func (key LabelPolicySearchKey) ToColumnName() string {
return LabelPolicyKeyAggregateID
case iam_model.LabelPolicySearchKeyState:
return LabelPolicyKeyState
case iam_model.LabelPolicySearchKeyInstanceID:
return LabelPolicyKeyInstanceID
default:
return ""
}

View File

@ -0,0 +1,77 @@
package migration
import "github.com/caos/zitadel/internal/eventstore"
//SetupStep is the command pushed on the eventstore
type SetupStep struct {
typ eventstore.EventType
migration Migration
Name string `json:"name"`
Error error `json:"error,omitempty"`
done bool
}
func setupStartedCmd(migration Migration) eventstore.Command {
return &SetupStep{
migration: migration,
typ: startedType,
Name: migration.String(),
}
}
func setupDoneCmd(migration Migration, err error) eventstore.Command {
s := &SetupStep{
typ: doneType,
migration: migration,
Name: migration.String(),
}
if err != nil {
s.typ = failedType
s.Error = err
}
return s
}
func (s *SetupStep) Aggregate() eventstore.Aggregate {
return eventstore.Aggregate{
ID: aggregateID,
Type: aggregateType,
ResourceOwner: "SYSTEM",
Version: "v1",
}
}
func (s *SetupStep) EditorService() string {
return "system"
}
func (s *SetupStep) EditorUser() string {
return "system"
}
func (s *SetupStep) Type() eventstore.EventType {
return s.typ
}
func (s *SetupStep) Data() interface{} {
return s
}
func (s *SetupStep) UniqueConstraints() []*eventstore.EventUniqueConstraint {
switch s.typ {
case startedType:
return []*eventstore.EventUniqueConstraint{
eventstore.NewAddEventUniqueConstraint("migration_started", s.migration.String(), "Errors.Step.Started.AlreadyExists"),
}
case failedType:
return []*eventstore.EventUniqueConstraint{
eventstore.NewRemoveEventUniqueConstraint("migration_started", s.migration.String()),
}
default:
return []*eventstore.EventUniqueConstraint{
eventstore.NewAddEventUniqueConstraint("migration_done", s.migration.String(), "Errors.Step.Done.AlreadyExists"),
}
}
}

View File

@ -0,0 +1,84 @@
package migration
import (
"context"
"encoding/json"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
)
const (
startedType = eventstore.EventType("system.migration.started")
doneType = eventstore.EventType("system.migration.done")
failedType = eventstore.EventType("system.migration.failed")
aggregateType = eventstore.AggregateType("system")
aggregateID = "SYSTEM"
)
type Migration interface {
String() string
Execute(context.Context) error
}
func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration) (err error) {
if should, err := shouldExec(ctx, es, migration); !should || err != nil {
return err
}
if _, err = es.Push(ctx, setupStartedCmd(migration)); err != nil {
return err
}
err = migration.Execute(ctx)
logging.OnError(err).Error("migration failed")
_, err = es.Push(ctx, setupDoneCmd(migration, err))
return err
}
func shouldExec(ctx context.Context, es *eventstore.Eventstore, migration Migration) (should bool, err error) {
events, err := es.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AddQuery().
AggregateTypes(aggregateType).
AggregateIDs(aggregateID).
EventTypes(startedType, doneType, failedType).
Builder())
if err != nil {
return false, err
}
if len(events) == 0 {
return true, nil
}
if events[len(events)-1].Type() == startedType {
return false, nil
}
for _, e := range events {
step := new(SetupStep)
err = json.Unmarshal(e.DataAsBytes(), step)
if err != nil {
return false, err
}
if step.Name != migration.String() {
continue
}
switch e.Type() {
case startedType, doneType:
//TODO: if started should we wait until done/failed?
return false, nil
case failedType:
//TODO: how to allow retries?
logging.WithFields("migration", migration.String()).Error("failed before")
return false, errors.ThrowInternal(nil, "MIGRA-mjI2E", "migration failed before")
}
}
return true, nil
}

View File

@ -26,6 +26,7 @@ const (
OrgProjectMappingSearchKeyProjectID
OrgProjectMappingSearchKeyOrgID
OrgProjectMappingSearchKeyProjectGrantID
OrgProjectMappingSearchKeyInstanceID
)
type OrgProjectMappingViewSearchQuery struct {

View File

@ -4,10 +4,12 @@ const (
OrgProjectMappingKeyProjectID = "project_id"
OrgProjectMappingKeyOrgID = "org_id"
OrgProjectMappingKeyProjectGrantID = "project_grant_id"
OrgProjectMappingKeyInstanceID = "instance_id"
)
type OrgProjectMapping struct {
ProjectID string `json:"-" gorm:"column:project_id;primary_key"`
OrgID string `json:"-" gorm:"column:org_id;primary_key"`
ProjectGrantID string `json:"-" gorm:"column:project_grant_id;"`
ProjectGrantID string `json:"-" gorm:"column:project_grant_id"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
}

View File

@ -57,6 +57,8 @@ func (key OrgProjectMappingSearchKey) ToColumnName() string {
return OrgProjectMappingKeyProjectID
case proj_model.OrgProjectMappingSearchKeyProjectGrantID:
return OrgProjectMappingKeyProjectGrantID
case proj_model.OrgProjectMappingSearchKeyInstanceID:
return OrgProjectMappingKeyInstanceID
default:
return ""
}

View File

@ -2,13 +2,15 @@ package model
import (
"encoding/json"
"time"
"github.com/caos/logging"
"github.com/lib/pq"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/project/model"
es_model "github.com/caos/zitadel/internal/project/repository/eventsourcing/model"
"github.com/lib/pq"
"time"
)
const (
@ -39,6 +41,7 @@ type ProjectGrant struct {
GrantID string `json:"grantId"`
GrantedOrgID string `json:"grantedOrgId"`
RoleKeys []string `json:"roleKeys"`
InstanceID string `json:"instanceID"`
}
func ProjectGrantFromModel(project *model.ProjectGrantView) *ProjectGrantView {

View File

@ -8,6 +8,8 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query/projection"
@ -33,6 +35,10 @@ var (
name: projection.ActionResourceOwnerCol,
table: actionTable,
}
ActionColumnInstanceID = Column{
name: projection.ActionInstanceIDCol,
table: actionTable,
}
ActionColumnSequence = Column{
name: projection.ActionSequenceCol,
table: actionTable,
@ -93,7 +99,11 @@ func (q *ActionSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
func (q *Queries) SearchActions(ctx context.Context, queries *ActionSearchQueries) (actions *Actions, err error) {
query, scan := prepareActionsQuery()
stmt, args, err := queries.toQuery(query).ToSql()
stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
ActionColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
}).
ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-SDgwg", "Errors.Query.InvalidRequest")
}
@ -116,6 +126,7 @@ func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string) (*
sq.Eq{
ActionColumnID.identifier(): id,
ActionColumnResourceOwner.identifier(): orgID,
ActionColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
}).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement")

View File

@ -8,6 +8,8 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/lib/pq"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query/projection"
@ -21,6 +23,14 @@ var (
name: projection.FlowTypeCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnChangeDate = Column{
name: projection.FlowChangeDateCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnSequence = Column{
name: projection.FlowSequenceCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnTriggerType = Column{
name: projection.FlowTriggerTypeCol,
table: flowsTriggersTable,
@ -29,6 +39,10 @@ var (
name: projection.FlowResourceOwnerCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnInstanceID = Column{
name: projection.FlowInstanceIDCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnTriggerSequence = Column{
name: projection.FlowActionTriggerSequenceCol,
table: flowsTriggersTable,
@ -40,10 +54,9 @@ var (
)
type Flow struct {
CreationDate time.Time //TODO: add in projection
ChangeDate time.Time //TODO: add in projection
ResourceOwner string //TODO: add in projection
Sequence uint64 //TODO: add in projection
ChangeDate time.Time
ResourceOwner string
Sequence uint64
Type domain.FlowType
TriggerActions map[domain.TriggerType][]*Action
@ -55,6 +68,7 @@ func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID s
sq.Eq{
FlowsTriggersColumnFlowType.identifier(): flowType,
FlowsTriggersColumnResourceOwner.identifier(): orgID,
FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
}).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-HBRh3", "Errors.Query.InvalidRequest")
@ -74,6 +88,7 @@ func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flow
FlowsTriggersColumnFlowType.identifier(): flowType,
FlowsTriggersColumnTriggerType.identifier(): triggerType,
FlowsTriggersColumnResourceOwner.identifier(): orgID,
FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
ActionColumnState.identifier(): domain.ActionStateActive,
},
).ToSql()
@ -92,7 +107,8 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string) (
stmt, scan := prepareFlowTypesQuery()
query, args, err := stmt.Where(
sq.Eq{
FlowsTriggersColumnActionID.identifier(): actionID,
FlowsTriggersColumnActionID.identifier(): actionID,
FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
},
).ToSql()
if err != nil {
@ -185,6 +201,9 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
FlowsTriggersColumnTriggerType.identifier(),
FlowsTriggersColumnTriggerSequence.identifier(),
FlowsTriggersColumnFlowType.identifier(),
FlowsTriggersColumnChangeDate.identifier(),
FlowsTriggersColumnSequence.identifier(),
FlowsTriggersColumnResourceOwner.identifier(),
).
From(flowsTriggersTable.name).
LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID)).
@ -194,7 +213,6 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
TriggerActions: make(map[domain.TriggerType][]*Action),
}
for rows.Next() {
// action := new(Action)
var (
actionID sql.NullString
actionCreationDate pq.NullTime
@ -207,7 +225,6 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
triggerType domain.TriggerType
triggerSequence int
flowType domain.FlowType
)
err := rows.Scan(
&actionID,
@ -220,12 +237,14 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
&actionScript,
&triggerType,
&triggerSequence,
&flowType,
&flow.Type,
&flow.ChangeDate,
&flow.Sequence,
&flow.ResourceOwner,
)
if err != nil {
return nil, err
}
flow.Type = flowType
if !actionID.Valid {
continue
}

Some files were not shown because too many files have changed in this diff Show More