feat: projections auto create their tables (#3324)

* begin init checks for projections

* first projection checks

* debug notification providers with query fixes

* more projections and first index

* more projections

* more projections

* finish projections

* fix tests (remove db name)

* create tables in setup

* fix logging / error handling

* add tenant to views

* rename tenant to instance_id

* add instance_id to all projections

* add instance_id to all queries

* correct instance_id on projections

* add instance_id to failed_events

* use separate context for instance

* implement features projection

* implement features projection

* remove unique constraint from setup when migration failed

* add error to failed setup event

* add instance_id to primary keys

* fix IAM projection

* remove old migrations folder

* fix keysFromYAML test
This commit is contained in:
Livio Amstutz 2022-03-23 09:02:39 +01:00 committed by GitHub
parent 9e13b70a3d
commit 56b916a2b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
400 changed files with 6508 additions and 8890 deletions

View File

@ -12,13 +12,13 @@ CREATE TABLE eventstore.events (
, editor_user TEXT NOT NULL , editor_user TEXT NOT NULL
, editor_service TEXT NOT NULL , editor_service TEXT NOT NULL
, resource_owner TEXT NOT NULL , resource_owner TEXT NOT NULL
, tenant TEXT , instance_id TEXT
, PRIMARY KEY (event_sequence DESC) USING HASH WITH BUCKET_COUNT = 10 , PRIMARY KEY (event_sequence DESC) USING HASH WITH BUCKET_COUNT = 10
, INDEX agg_type_agg_id (aggregate_type, aggregate_id) , INDEX agg_type_agg_id (aggregate_type, aggregate_id)
, INDEX agg_type (aggregate_type) , INDEX agg_type (aggregate_type)
, INDEX agg_type_seq (aggregate_type, event_sequence DESC) , INDEX agg_type_seq (aggregate_type, event_sequence DESC)
STORING (id, event_type, aggregate_id, aggregate_version, previous_aggregate_sequence, creation_date, event_data, editor_user, editor_service, resource_owner, tenant, previous_aggregate_type_sequence) STORING (id, event_type, aggregate_id, aggregate_version, previous_aggregate_sequence, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_type_sequence)
, INDEX max_sequence (aggregate_type, aggregate_id, event_sequence DESC) , INDEX max_sequence (aggregate_type, aggregate_id, event_sequence DESC)
, CONSTRAINT previous_sequence_unique UNIQUE (previous_aggregate_sequence DESC) , CONSTRAINT previous_sequence_unique UNIQUE (previous_aggregate_sequence DESC)
, CONSTRAINT prev_agg_type_seq_unique UNIQUE(previous_aggregate_type_sequence) , CONSTRAINT prev_agg_type_seq_unique UNIQUE(previous_aggregate_type_sequence)

View File

@ -155,7 +155,7 @@ func Test_keysFromYAML(t *testing.T) {
if tt.res.err != nil && !tt.res.err(err) { if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err) t.Errorf("got wrong err: %v ", err)
} }
assert.EqualValues(t, got, tt.res.keys) assert.ElementsMatch(t, got, tt.res.keys)
}) })
} }
} }

View File

@ -0,0 +1,34 @@
package setup
import (
"context"
"database/sql"
_ "embed"
)
var (
//go:embed 01_sql/adminapi.sql
createAdminViews string
//go:embed 01_sql/auth.sql
createAuthViews string
//go:embed 01_sql/authz.sql
createAuthzViews string
//go:embed 01_sql/notification.sql
createNotificationViews string
//go:embed 01_sql/projections.sql
createProjections string
)
type ProjectionTable struct {
dbClient *sql.DB
}
func (mig *ProjectionTable) Execute(ctx context.Context) error {
stmt := createAdminViews + createAuthViews + createAuthzViews + createNotificationViews + createProjections
_, err := mig.dbClient.ExecContext(ctx, stmt)
return err
}
func (mig *ProjectionTable) String() string {
return "01_tables"
}

View File

@ -0,0 +1,54 @@
CREATE SCHEMA adminapi;
CREATE TABLE adminapi.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE adminapi.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE adminapi.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE adminapi.styling (
aggregate_id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
label_policy_state INT2 NOT NULL DEFAULT 0:::INT2,
sequence INT8 NULL,
primary_color STRING NULL,
background_color STRING NULL,
warn_color STRING NULL,
font_color STRING NULL,
primary_color_dark STRING NULL,
background_color_dark STRING NULL,
warn_color_dark STRING NULL,
font_color_dark STRING NULL,
logo_url STRING NULL,
icon_url STRING NULL,
logo_dark_url STRING NULL,
icon_dark_url STRING NULL,
font_url STRING NULL,
err_msg_popup BOOL NULL,
disable_watermark BOOL NULL,
hide_login_name_suffix BOOL NULL,
instance_id STRING NOT NULL,
PRIMARY KEY (aggregate_id, label_policy_state)
);

View File

@ -0,0 +1,222 @@
CREATE SCHEMA auth;
CREATE TABLE auth.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE auth.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE auth.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE auth.users (
id STRING NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
user_state INT2 NULL,
password_set BOOL NULL,
password_change_required BOOL NULL,
password_change TIMESTAMPTZ NULL,
last_login TIMESTAMPTZ NULL,
user_name STRING NULL,
login_names STRING[] NULL,
preferred_login_name STRING NULL,
first_name STRING NULL,
last_name STRING NULL,
nick_name STRING NULL,
display_name STRING NULL,
preferred_language STRING NULL,
gender INT2 NULL,
email STRING NULL,
is_email_verified BOOL NULL,
phone STRING NULL,
is_phone_verified BOOL NULL,
country STRING NULL,
locality STRING NULL,
postal_code STRING NULL,
region STRING NULL,
street_address STRING NULL,
otp_state INT2 NULL,
mfa_max_set_up INT2 NULL,
mfa_init_skipped TIMESTAMPTZ NULL,
sequence INT8 NULL,
init_required BOOL NULL,
username_change_required BOOL NULL,
machine_name STRING NULL,
machine_description STRING NULL,
user_type STRING NULL,
u2f_tokens BYTES NULL,
passwordless_tokens BYTES NULL,
avatar_key STRING NULL,
passwordless_init_required BOOL NULL,
password_init_required BOOL NULL,
instance_id STRING NULL,
PRIMARY KEY (id)
);
CREATE TABLE auth.user_sessions (
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
state INT2 NULL,
user_agent_id STRING NULL,
user_id STRING NULL,
user_name STRING NULL,
password_verification TIMESTAMPTZ NULL,
second_factor_verification TIMESTAMPTZ NULL,
multi_factor_verification TIMESTAMPTZ NULL,
sequence INT8 NULL,
second_factor_verification_type INT2 NULL,
multi_factor_verification_type INT2 NULL,
user_display_name STRING NULL,
login_name STRING NULL,
external_login_verification TIMESTAMPTZ NULL,
selected_idp_config_id STRING NULL,
passwordless_verification TIMESTAMPTZ NULL,
avatar_key STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (user_agent_id, user_id)
);
CREATE TABLE auth.user_external_idps (
external_user_id STRING NOT NULL,
idp_config_id STRING NOT NULL,
user_id STRING NULL,
idp_name STRING NULL,
user_display_name STRING NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
resource_owner STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (external_user_id, idp_config_id)
);
CREATE TABLE auth.tokens (
id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
application_id STRING NULL,
user_agent_id STRING NULL,
user_id STRING NULL,
expiration TIMESTAMPTZ NULL,
sequence INT8 NULL,
scopes STRING[] NULL,
audience STRING[] NULL,
preferred_language STRING NULL,
refresh_token_id STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (id),
INDEX user_user_agent_idx (user_id, user_agent_id)
);
CREATE TABLE auth.refresh_tokens (
id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
token STRING NULL,
client_id STRING NOT NULL,
user_agent_id STRING NOT NULL,
user_id STRING NOT NULL,
auth_time TIMESTAMPTZ NULL,
idle_expiration TIMESTAMPTZ NULL,
expiration TIMESTAMPTZ NULL,
sequence INT8 NULL,
scopes STRING[] NULL,
audience STRING[] NULL,
amr STRING[] NULL,
instance_id STRING NULL,
PRIMARY KEY (id),
UNIQUE INDEX unique_client_user_index (client_id ASC, user_agent_id ASC, user_id ASC)
);
CREATE TABLE auth.org_project_mapping (
org_id STRING NOT NULL,
project_id STRING NOT NULL,
project_grant_id STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (org_id, project_id)
);
CREATE TABLE auth.idp_providers (
aggregate_id STRING NOT NULL,
idp_config_id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
name STRING NULL,
idp_config_type INT2 NULL,
idp_provider_type INT2 NULL,
idp_state INT2 NULL,
styling_type INT2 NULL,
instance_id STRING NULL,
PRIMARY KEY (aggregate_id, idp_config_id)
);
CREATE TABLE auth.idp_configs (
idp_config_id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
aggregate_id STRING NULL,
name STRING NULL,
idp_state INT2 NULL,
idp_provider_type INT2 NULL,
is_oidc BOOL NULL,
oidc_client_id STRING NULL,
oidc_client_secret JSONB NULL,
oidc_issuer STRING NULL,
oidc_scopes STRING[] NULL,
oidc_idp_display_name_mapping INT2 NULL,
oidc_idp_username_mapping INT2 NULL,
styling_type INT2 NULL,
oauth_authorization_endpoint STRING NULL,
oauth_token_endpoint STRING NULL,
auto_register BOOL NULL,
jwt_endpoint STRING NULL,
jwt_keys_endpoint STRING NULL,
jwt_header_name STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (idp_config_id)
);
CREATE TABLE auth.auth_requests (
id STRING NOT NULL,
request JSONB NULL,
code STRING NULL,
request_type INT2 NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
instance_id STRING NULL,
PRIMARY KEY (id),
INDEX auth_code_idx (code)
);

View File

@ -0,0 +1,44 @@
CREATE SCHEMA authz;
CREATE TABLE authz.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE authz.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE authz.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE authz.user_memberships (
user_id STRING NOT NULL,
member_type INT2 NOT NULL,
aggregate_id STRING NOT NULL,
object_id STRING NOT NULL,
roles STRING[] NULL,
display_name STRING NULL,
resource_owner STRING NULL,
resource_owner_name STRING NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
sequence INT8 NULL,
instance_id STRING NULL,
PRIMARY KEY (user_id, member_type, aggregate_id, object_id)
);

View File

@ -0,0 +1,52 @@
CREATE SCHEMA notification;
CREATE TABLE notification.locks (
locker_id TEXT,
locked_until TIMESTAMPTZ(3),
projection_name TEXT,
PRIMARY KEY (projection_name)
);
CREATE TABLE notification.current_sequences (
projection_name TEXT,
aggregate_type TEXT,
current_sequence BIGINT,
timestamp TIMESTAMPTZ,
PRIMARY KEY (projection_name, aggregate_type)
);
CREATE TABLE notification.failed_events (
projection_name TEXT,
failed_sequence BIGINT,
failure_count SMALLINT,
error TEXT,
PRIMARY KEY (projection_name, failed_sequence)
);
CREATE TABLE notification.notify_users (
id STRING NOT NULL,
creation_date TIMESTAMPTZ NULL,
change_date TIMESTAMPTZ NULL,
resource_owner STRING NULL,
user_name STRING NULL,
first_name STRING NULL,
last_name STRING NULL,
nick_name STRING NULL,
display_name STRING NULL,
preferred_language STRING NULL,
gender INT2 NULL,
last_email STRING NULL,
verified_email STRING NULL,
last_phone STRING NULL,
verified_phone STRING NULL,
sequence INT8 NULL,
password_set BOOL NULL,
login_names STRING NULL,
preferred_login_name STRING NULL,
instance_id STRING NULL,
PRIMARY KEY (id)
);

View File

@ -1,10 +1,4 @@
CREATE DATABASE zitadel; CREATE TABLE projections.locks (
GRANT SELECT, INSERT, UPDATE, DELETE ON DATABASE zitadel TO queries;
use zitadel;
CREATE SCHEMA zitadel.projections AUTHORIZATION queries;
CREATE TABLE zitadel.projections.locks (
locker_id TEXT, locker_id TEXT,
locked_until TIMESTAMPTZ(3), locked_until TIMESTAMPTZ(3),
projection_name TEXT, projection_name TEXT,
@ -12,7 +6,7 @@ CREATE TABLE zitadel.projections.locks (
PRIMARY KEY (projection_name) PRIMARY KEY (projection_name)
); );
CREATE TABLE zitadel.projections.current_sequences ( CREATE TABLE projections.current_sequences (
projection_name TEXT, projection_name TEXT,
aggregate_type TEXT, aggregate_type TEXT,
current_sequence BIGINT, current_sequence BIGINT,
@ -21,11 +15,12 @@ CREATE TABLE zitadel.projections.current_sequences (
PRIMARY KEY (projection_name, aggregate_type) PRIMARY KEY (projection_name, aggregate_type)
); );
CREATE TABLE zitadel.projections.failed_events ( CREATE TABLE projections.failed_events (
projection_name TEXT, projection_name TEXT,
failed_sequence BIGINT, failed_sequence BIGINT,
failure_count SMALLINT, failure_count SMALLINT,
error TEXT, error TEXT,
instance_id TEXT,
PRIMARY KEY (projection_name, failed_sequence) PRIMARY KEY (projection_name, failed_sequence, instance_id)
); );

13
cmd/admin/setup/config.go Normal file
View File

@ -0,0 +1,13 @@
package setup
import (
"github.com/caos/zitadel/internal/database"
)
type Config struct {
Database database.Config
}
type Steps struct {
S1ProjectionTable *ProjectionTable
}

View File

@ -1,10 +1,22 @@
package setup package setup
import ( import (
"bytes"
"context"
_ "embed" _ "embed"
"github.com/caos/logging" "github.com/caos/logging"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/caos/zitadel/internal/database"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/migration"
)
var (
//go:embed steps.yaml
defaultSteps []byte
) )
func New() *cobra.Command { func New() *cobra.Command {
@ -14,9 +26,33 @@ func New() *cobra.Command {
Long: `sets up data to start ZITADEL. Long: `sets up data to start ZITADEL.
Requirements: Requirements:
- cockroachdb`, - cockroachdb`,
RunE: func(cmd *cobra.Command, args []string) error { Run: func(cmd *cobra.Command, args []string) {
logging.Info("hello world") config := new(Config)
return nil err := viper.Unmarshal(config)
logging.OnError(err).Fatal("unable to read config")
v := viper.New()
v.SetConfigType("yaml")
err = v.ReadConfig(bytes.NewBuffer(defaultSteps))
logging.OnError(err).Fatal("unable to read setup steps")
steps := new(Steps)
err = v.Unmarshal(steps)
logging.OnError(err).Fatal("unable to read steps")
setup(config, steps)
}, },
} }
} }
func setup(config *Config, steps *Steps) {
dbClient, err := database.Connect(config.Database)
logging.OnError(err).Fatal("unable to connect to database")
eventstoreClient, err := eventstore.Start(dbClient)
logging.OnError(err).Fatal("unable to start eventstore")
steps.S1ProjectionTable = &ProjectionTable{dbClient: dbClient}
migration.Migrate(context.Background(), eventstoreClient, steps.S1ProjectionTable)
}

View File

@ -41,7 +41,6 @@ import (
"github.com/caos/zitadel/internal/crypto" "github.com/caos/zitadel/internal/crypto"
cryptoDB "github.com/caos/zitadel/internal/crypto/database" cryptoDB "github.com/caos/zitadel/internal/crypto/database"
"github.com/caos/zitadel/internal/database" "github.com/caos/zitadel/internal/database"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/eventstore" "github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/id" "github.com/caos/zitadel/internal/id"
"github.com/caos/zitadel/internal/notification" "github.com/caos/zitadel/internal/notification"
@ -308,7 +307,7 @@ func shutdownServer(ctx context.Context, server *http.Server) error {
//TODO:!!??!! //TODO:!!??!!
func consoleClientID(ctx context.Context, queries *query.Queries) (string, error) { func consoleClientID(ctx context.Context, queries *query.Queries) (string, error) {
iam, err := queries.IAMByID(ctx, domain.IAMID) iam, err := queries.IAM(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -32,7 +32,7 @@ type API struct {
type health interface { type health interface {
Health(ctx context.Context) error Health(ctx context.Context) error
IAMByID(ctx context.Context, id string) (*query.IAM, error) IAM(ctx context.Context) (*query.IAM, error)
} }
func New( func New(
@ -107,7 +107,7 @@ func (a *API) healthHandler() http.Handler {
return nil return nil
}, },
func(ctx context.Context) error { func(ctx context.Context) error {
iam, err := a.health.IAMByID(ctx, domain.IAMID) iam, err := a.health.IAM(ctx)
if err != nil && !errors.IsNotFound(err) { if err != nil && !errors.IsNotFound(err) {
return errors.ThrowPreconditionFailed(err, "API-dsgT2", "IAM SETUP CHECK FAILED") return errors.ThrowPreconditionFailed(err, "API-dsgT2", "IAM SETUP CHECK FAILED")
} }

View File

@ -15,12 +15,12 @@ const (
requestPermissionsKey key = 1 requestPermissionsKey key = 1
dataKey key = 2 dataKey key = 2
allPermissionsKey key = 3 allPermissionsKey key = 3
instanceKey key = 4
) )
type CtxData struct { type CtxData struct {
UserID string UserID string
OrgID string OrgID string
TenantID string //TODO: Set Tenant ID on some context
ProjectID string ProjectID string
AgentID string AgentID string
PreferredLanguage string PreferredLanguage string
@ -31,6 +31,10 @@ func (ctxData CtxData) IsZero() bool {
return ctxData.UserID == "" || ctxData.OrgID == "" return ctxData.UserID == "" || ctxData.OrgID == ""
} }
type Instance struct {
ID string
}
type Grants []*Grant type Grants []*Grant
type Grant struct { type Grant struct {
@ -43,7 +47,7 @@ type Memberships []*Membership
type Membership struct { type Membership struct {
MemberType MemberType MemberType MemberType
AggregateID string AggregateID string
//ObjectID differs from aggregate id if obejct is sub of an aggregate //ObjectID differs from aggregate id if object is sub of an aggregate
ObjectID string ObjectID string
Roles []string Roles []string
@ -112,6 +116,11 @@ func GetCtxData(ctx context.Context) CtxData {
return ctxData return ctxData
} }
func GetInstance(ctx context.Context) Instance {
instance, _ := ctx.Value(instanceKey).(Instance)
return instance
}
func GetRequestPermissionsFromCtx(ctx context.Context) []string { func GetRequestPermissionsFromCtx(ctx context.Context) []string {
ctxPermission, _ := ctx.Value(requestPermissionsKey).([]string) ctxPermission, _ := ctx.Value(requestPermissionsKey).([]string)
return ctxPermission return ctxPermission

View File

@ -2,11 +2,13 @@ package authz
import "context" import "context"
func NewMockContext(tenantID, orgID, userID string) context.Context { func NewMockContext(instanceID, orgID, userID string) context.Context {
return context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID, TenantID: tenantID}) ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
return context.WithValue(ctx, instanceKey, instanceID)
} }
func NewMockContextWithPermissions(tenantID, orgID, userID string, permissions []string) context.Context { func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context {
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID, TenantID: tenantID}) ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
ctx = context.WithValue(ctx, instanceKey, instanceID)
return context.WithValue(ctx, requestPermissionsKey, permissions) return context.WithValue(ctx, requestPermissionsKey, permissions)
} }

View File

@ -3,12 +3,12 @@ package admin
import ( import (
"context" "context"
"golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/grpc/object" "github.com/caos/zitadel/internal/api/grpc/object"
"github.com/caos/zitadel/internal/api/grpc/text" "github.com/caos/zitadel/internal/api/grpc/text"
"github.com/caos/zitadel/internal/domain"
caos_errors "github.com/caos/zitadel/internal/errors" caos_errors "github.com/caos/zitadel/internal/errors"
admin_pb "github.com/caos/zitadel/pkg/grpc/admin" admin_pb "github.com/caos/zitadel/pkg/grpc/admin"
"golang.org/x/text/language"
) )
func (s *Server) GetSupportedLanguages(ctx context.Context, req *admin_pb.GetSupportedLanguagesRequest) (*admin_pb.GetSupportedLanguagesResponse, error) { func (s *Server) GetSupportedLanguages(ctx context.Context, req *admin_pb.GetSupportedLanguagesRequest) (*admin_pb.GetSupportedLanguagesResponse, error) {
@ -34,9 +34,5 @@ func (s *Server) SetDefaultLanguage(ctx context.Context, req *admin_pb.SetDefaul
} }
func (s *Server) GetDefaultLanguage(ctx context.Context, req *admin_pb.GetDefaultLanguageRequest) (*admin_pb.GetDefaultLanguageResponse, error) { func (s *Server) GetDefaultLanguage(ctx context.Context, req *admin_pb.GetDefaultLanguageRequest) (*admin_pb.GetDefaultLanguageResponse, error) {
iam, err := s.query.IAMByID(ctx, domain.IAMID) return &admin_pb.GetDefaultLanguageResponse{Language: s.query.GetDefaultLanguage(ctx).String()}, nil
if err != nil {
return nil, err
}
return &admin_pb.GetDefaultLanguageResponse{Language: iam.DefaultLanguage.String()}, nil
} }

View File

@ -152,7 +152,7 @@ func (s *Server) ListMyProjectOrgs(ctx context.Context, req *auth_pb.ListMyProje
return nil, err return nil, err
} }
iam, err := s.query.IAMByID(ctx, domain.IAMID) iam, err := s.query.IAM(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -3,12 +3,11 @@ package management
import ( import (
"context" "context"
"github.com/caos/zitadel/internal/domain"
mgmt_pb "github.com/caos/zitadel/pkg/grpc/management" mgmt_pb "github.com/caos/zitadel/pkg/grpc/management"
) )
func (s *Server) GetIAM(ctx context.Context, req *mgmt_pb.GetIAMRequest) (*mgmt_pb.GetIAMResponse, error) { func (s *Server) GetIAM(ctx context.Context, _ *mgmt_pb.GetIAMRequest) (*mgmt_pb.GetIAMResponse, error) {
iam, err := s.query.IAMByID(ctx, domain.IAMID) iam, err := s.query.IAM(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -206,8 +206,8 @@ func (s *Server) SetPrimaryOrgDomain(ctx context.Context, req *mgmt_pb.SetPrimar
}, nil }, nil
} }
func (s *Server) ListOrgMemberRoles(ctx context.Context, req *mgmt_pb.ListOrgMemberRolesRequest) (*mgmt_pb.ListOrgMemberRolesResponse, error) { func (s *Server) ListOrgMemberRoles(ctx context.Context, _ *mgmt_pb.ListOrgMemberRolesRequest) (*mgmt_pb.ListOrgMemberRolesResponse, error) {
iam, err := s.query.IAMByID(ctx, domain.IAMID) iam, err := s.query.IAM(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -21,6 +21,10 @@ func UserAgentIDFromCtx(ctx context.Context) (string, bool) {
return userAgentID, ok return userAgentID, ok
} }
func InstanceIDFromCtx(ctx context.Context) string {
return "" //TODO: implement
}
type UserAgent struct { type UserAgent struct {
ID string ID string
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/api/http/middleware" "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query" "github.com/caos/zitadel/internal/query"
@ -45,7 +46,8 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
if !ok { if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id") return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
} }
resp, err := o.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID) instanceID := authz.GetInstance(ctx).ID
resp, err := o.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,7 +57,9 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) { func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
resp, err := o.repo.AuthRequestByCode(ctx, code)
instanceID := authz.GetInstance(ctx).ID
resp, err := o.repo.AuthRequestByCode(ctx, code, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,13 +73,16 @@ func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err erro
if !ok { if !ok {
return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id") return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
} }
return o.repo.SaveAuthCode(ctx, id, code, userAgentID) instanceID := authz.GetInstance(ctx).ID
return o.repo.SaveAuthCode(ctx, id, code, userAgentID, instanceID)
} }
func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) { func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
return o.repo.DeleteAuthRequest(ctx, id)
instanceID := authz.GetInstance(ctx).ID
return o.repo.DeleteAuthRequest(ctx, id, instanceID)
} }
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) { func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {

View File

@ -10,6 +10,7 @@ import (
"github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/authz"
http_utils "github.com/caos/zitadel/internal/api/http" http_utils "github.com/caos/zitadel/internal/api/http"
model2 "github.com/caos/zitadel/internal/auth_request/model" model2 "github.com/caos/zitadel/internal/auth_request/model"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
@ -132,6 +133,7 @@ func CreateAuthRequestToBusiness(ctx context.Context, authReq *oidc.AuthRequest,
SelectedIDPConfigID: GetSelectedIDPIDFromScopes(authReq.Scopes), SelectedIDPConfigID: GetSelectedIDPIDFromScopes(authReq.Scopes),
MaxAuthAge: MaxAgeToBusiness(authReq.MaxAge), MaxAuthAge: MaxAgeToBusiness(authReq.MaxAge),
UserID: userID, UserID: userID,
InstanceID: authz.GetInstance(ctx).ID,
Request: &domain.AuthRequestOIDC{ Request: &domain.AuthRequestOIDC{
Scopes: authReq.Scopes, Scopes: authReq.Scopes,
ResponseType: ResponseTypeToBusiness(authReq.ResponseType), ResponseType: ResponseTypeToBusiness(authReq.ResponseType),

View File

@ -5,6 +5,7 @@ import (
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
) )
@ -19,7 +20,8 @@ func (l *Login) getAuthRequest(r *http.Request) (*domain.AuthRequest, error) {
return nil, nil return nil, nil
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
return l.authRepo.AuthRequestByID(r.Context(), authRequestID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
return l.authRepo.AuthRequestByID(r.Context(), authRequestID, userAgentID, instanceID)
} }
func (l *Login) getAuthRequestAndParseData(r *http.Request, data interface{}) (*domain.AuthRequest, error) { func (l *Login) getAuthRequestAndParseData(r *http.Request, data interface{}) (*domain.AuthRequest, error) {

View File

@ -16,7 +16,7 @@ func (l *Login) customExternalUserMapping(ctx context.Context, user *domain.Exte
resourceOwner = config.AggregateID resourceOwner = config.AggregateID
} }
if resourceOwner == domain.IAMID { if resourceOwner == domain.IAMID {
iam, err := l.query.IAMByID(ctx, domain.IAMID) iam, err := l.query.IAM(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -11,6 +11,8 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/crypto" "github.com/caos/zitadel/internal/crypto"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
@ -87,7 +89,8 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderLogin(w, r, authReq, err) l.renderLogin(w, r, authReq, err)
return return
@ -139,7 +142,8 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
@ -198,12 +202,13 @@ func (l *Login) handleExternalUserAuthenticated(w http.ResponseWriter, r *http.R
return return
} }
err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, userAgentID, externalUser, domain.BrowserInfoFromRequest(r)) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, userAgentID, instanceID, externalUser, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
if errors.IsNotFound(err) { if errors.IsNotFound(err) {
err = nil err = nil
} }
iam, err := l.query.IAMByID(r.Context(), domain.IAMID) iam, err := l.query.IAM(r.Context())
if err != nil { if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err) l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
return return
@ -226,7 +231,7 @@ func (l *Login) handleExternalUserAuthenticated(w http.ResponseWriter, r *http.R
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIAMPolicy, human, idpLinking, err) l.renderExternalNotFoundOption(w, r, authReq, iam, orgIAMPolicy, human, idpLinking, err)
return return
} }
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID) authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIAMPolicy, human, idpLinking, err) l.renderExternalNotFoundOption(w, r, authReq, iam, orgIAMPolicy, human, idpLinking, err)
return return
@ -235,7 +240,7 @@ func (l *Login) handleExternalUserAuthenticated(w http.ResponseWriter, r *http.R
return return
} }
if len(externalUser.Metadatas) > 0 { if len(externalUser.Metadatas) > 0 {
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID) authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID, instanceID)
if err != nil { if err != nil {
return return
} }
@ -254,7 +259,7 @@ func (l *Login) renderExternalNotFoundOption(w http.ResponseWriter, r *http.Requ
errID, errMessage = l.getErrorMessage(r, err) errID, errMessage = l.getErrorMessage(r, err)
} }
if orgIAMPolicy == nil { if orgIAMPolicy == nil {
iam, err = l.query.IAMByID(r.Context(), domain.IAMID) iam, err = l.query.IAM(r.Context())
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
@ -324,7 +329,8 @@ func (l *Login) handleExternalNotFoundOptionCheck(w http.ResponseWriter, r *http
return return
} else if data.ResetLinking { } else if data.ResetLinking {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.ResetLinkingUsers(r.Context(), authReq.ID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.ResetLinkingUsers(r.Context(), authReq.ID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err) l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
} }
@ -335,7 +341,7 @@ func (l *Login) handleExternalNotFoundOptionCheck(w http.ResponseWriter, r *http
} }
func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) { func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) {
iam, err := l.query.IAMByID(r.Context(), domain.IAMID) iam, err := l.query.IAM(r.Context())
if err != nil { if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err) l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
return return
@ -362,6 +368,7 @@ func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authR
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
instanceID := authz.GetInstance(r.Context()).ID
if len(authReq.LinkingUsers) == 0 { if len(authReq.LinkingUsers) == 0 {
l.renderError(w, r, authReq, caos_errors.ThrowPreconditionFailed(nil, "LOGIN-asfg3", "Errors.ExternalIDP.NoExternalUserData")) l.renderError(w, r, authReq, caos_errors.ThrowPreconditionFailed(nil, "LOGIN-asfg3", "Errors.ExternalIDP.NoExternalUserData"))
return return
@ -373,12 +380,12 @@ func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authR
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIamPolicy, nil, nil, err) l.renderExternalNotFoundOption(w, r, authReq, iam, orgIamPolicy, nil, nil, err)
return return
} }
err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, memberRoles, authReq.ID, userAgentID, resourceOwner, metadata, domain.BrowserInfoFromRequest(r)) err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, memberRoles, authReq.ID, userAgentID, resourceOwner, instanceID, metadata, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
l.renderExternalNotFoundOption(w, r, authReq, iam, orgIamPolicy, user, externalIDP, err) l.renderExternalNotFoundOption(w, r, authReq, iam, orgIamPolicy, user, externalIDP, err)
return return
} }
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID) authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, instanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return

View File

@ -8,6 +8,7 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
iam_model "github.com/caos/zitadel/internal/iam/model" iam_model "github.com/caos/zitadel/internal/iam/model"
@ -67,7 +68,8 @@ func (l *Login) handleExternalRegister(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, idpConfig.IDPConfigID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderLogin(w, r, authReq, err) l.renderLogin(w, r, authReq, err)
return return
@ -87,7 +89,8 @@ func (l *Login) handleExternalRegisterCallback(w http.ResponseWriter, r *http.Re
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
@ -111,7 +114,7 @@ func (l *Login) handleExternalRegisterCallback(w http.ResponseWriter, r *http.Re
} }
func (l *Login) handleExternalUserRegister(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, idpConfig *iam_model.IDPConfigView, userAgentID string, tokens *oidc.Tokens) { func (l *Login) handleExternalUserRegister(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, idpConfig *iam_model.IDPConfigView, userAgentID string, tokens *oidc.Tokens) {
iam, err := l.query.IAMByID(r.Context(), domain.IAMID) iam, err := l.query.IAM(r.Context())
if err != nil { if err != nil {
l.renderRegisterOption(w, r, authReq, err) l.renderRegisterOption(w, r, authReq, err)
return return
@ -204,7 +207,7 @@ func (l *Login) handleExternalRegisterCheck(w http.ResponseWriter, r *http.Reque
return return
} }
iam, err := l.query.IAMByID(r.Context(), domain.IAMID) iam, err := l.query.IAM(r.Context())
if err != nil { if err != nil {
l.renderRegisterOption(w, r, authReq, err) l.renderRegisterOption(w, r, authReq, err)
return return

View File

@ -12,6 +12,7 @@ import (
"github.com/caos/oidc/pkg/client/rp" "github.com/caos/oidc/pkg/client/rp"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/zitadel/internal/api/authz"
http_util "github.com/caos/zitadel/internal/api/http" http_util "github.com/caos/zitadel/internal/api/http"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
@ -44,7 +45,8 @@ func (l *Login) handleJWTRequest(w http.ResponseWriter, r *http.Request) {
l.renderError(w, r, nil, err) l.renderError(w, r, nil, err)
return return
} }
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
@ -82,13 +84,13 @@ func (l *Login) handleJWTExtraction(w http.ResponseWriter, r *http.Request, auth
return return
} }
metadata := externalUser.Metadatas metadata := externalUser.Metadatas
err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, authReq.AgentID, externalUser, domain.BrowserInfoFromRequest(r)) err = l.authRepo.CheckExternalUserLogin(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID, externalUser, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
l.jwtExtractionUserNotFound(w, r, authReq, idpConfig, tokens, err) l.jwtExtractionUserNotFound(w, r, authReq, idpConfig, tokens, err)
return return
} }
if len(metadata) > 0 { if len(metadata) > 0 {
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID) authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
@ -115,7 +117,7 @@ func (l *Login) jwtExtractionUserNotFound(w http.ResponseWriter, r *http.Request
l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err) l.renderExternalNotFoundOption(w, r, authReq, nil, nil, nil, nil, err)
return return
} }
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID) authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
@ -133,12 +135,12 @@ func (l *Login) jwtExtractionUserNotFound(w http.ResponseWriter, r *http.Request
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
} }
err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, nil, authReq.ID, authReq.AgentID, resourceOwner, metadata, domain.BrowserInfoFromRequest(r)) err = l.authRepo.AutoRegisterExternalUser(setContext(r.Context(), resourceOwner), user, externalIDP, nil, authReq.ID, authReq.AgentID, resourceOwner, authReq.InstanceID, metadata, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
} }
authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID) authReq, err = l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
@ -207,7 +209,8 @@ func (l *Login) handleJWTCallback(w http.ResponseWriter, r *http.Request) {
l.renderError(w, r, nil, err) l.renderError(w, r, nil, err)
return return
} }
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.AuthRequestID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return

View File

@ -3,6 +3,7 @@ package login
import ( import (
"net/http" "net/http"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
) )
@ -13,7 +14,8 @@ const (
func (l *Login) linkUsers(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { func (l *Login) linkUsers(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.LinkExternalUsers(setContext(r.Context(), authReq.UserOrgID), authReq.ID, userAgentID, domain.BrowserInfoFromRequest(r)) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.LinkExternalUsers(setContext(r.Context(), authReq.UserOrgID), authReq.ID, userAgentID, instanceID, domain.BrowserInfoFromRequest(r))
l.renderLinkUsersDone(w, r, authReq, err) l.renderLinkUsersDone(w, r, authReq, err)
} }

View File

@ -3,6 +3,7 @@ package login
import ( import (
"net/http" "net/http"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
@ -59,8 +60,9 @@ func (l *Login) handleLoginNameCheck(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
instanceID := authz.GetInstance(r.Context()).ID
loginName := data.LoginName loginName := data.LoginName
err = l.authRepo.CheckLoginName(r.Context(), authReq.ID, loginName, userAgentID) err = l.authRepo.CheckLoginName(r.Context(), authReq.ID, loginName, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderLogin(w, r, authReq, err) l.renderLogin(w, r, authReq, err)
return return

View File

@ -3,6 +3,7 @@ package login
import ( import (
"net/http" "net/http"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
) )
@ -35,7 +36,8 @@ func (l *Login) handleMFAVerify(w http.ResponseWriter, r *http.Request) {
} }
if data.MFAType == domain.MFATypeOTP { if data.MFAType == domain.MFATypeOTP {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.VerifyMFAOTP(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Code, userAgentID, domain.BrowserInfoFromRequest(r)) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.VerifyMFAOTP(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Code, userAgentID, instanceID, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
l.renderMFAVerifySelected(w, r, authReq, step, domain.MFATypeOTP, err) l.renderMFAVerifySelected(w, r, authReq, step, domain.MFATypeOTP, err)
return return

View File

@ -6,6 +6,7 @@ import (
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
) )
@ -29,7 +30,8 @@ func (l *Login) renderU2FVerification(w http.ResponseWriter, r *http.Request, au
var webAuthNLogin *domain.WebAuthNLogin var webAuthNLogin *domain.WebAuthNLogin
if err == nil { if err == nil {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
webAuthNLogin, err = l.authRepo.BeginMFAU2FLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
webAuthNLogin, err = l.authRepo.BeginMFAU2FLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, instanceID)
} }
if err != nil { if err != nil {
errID, errMessage = l.getErrorMessage(r, err) errID, errMessage = l.getErrorMessage(r, err)
@ -70,7 +72,8 @@ func (l *Login) handleU2FVerification(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.VerifyMFAU2F(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, credData, domain.BrowserInfoFromRequest(r)) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.VerifyMFAU2F(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, instanceID, credData, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
l.renderU2FVerification(w, r, authReq, step.MFAProviders, err) l.renderU2FVerification(w, r, authReq, step.MFAProviders, err)
return return

View File

@ -4,8 +4,6 @@ import (
"net/http" "net/http"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
) )
const ( const (
@ -40,8 +38,7 @@ func (l *Login) handlePasswordCheck(w http.ResponseWriter, r *http.Request) {
l.renderError(w, r, authReq, err) l.renderError(w, r, authReq, err)
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) err = l.authRepo.VerifyPassword(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Password, authReq.AgentID, authReq.InstanceID, domain.BrowserInfoFromRequest(r))
err = l.authRepo.VerifyPassword(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.UserID, authReq.UserOrgID, data.Password, userAgentID, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
l.renderPassword(w, r, authReq, err) l.renderPassword(w, r, authReq, err)
return return

View File

@ -5,8 +5,6 @@ import (
"net/http" "net/http"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
http_mw "github.com/caos/zitadel/internal/api/http/middleware"
) )
const ( const (
@ -27,8 +25,7 @@ func (l *Login) renderPasswordlessVerification(w http.ResponseWriter, r *http.Re
var errID, errMessage, credentialData string var errID, errMessage, credentialData string
var webAuthNLogin *domain.WebAuthNLogin var webAuthNLogin *domain.WebAuthNLogin
if err == nil { if err == nil {
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) webAuthNLogin, err = l.authRepo.BeginPasswordlessLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, authReq.AgentID, authReq.InstanceID)
webAuthNLogin, err = l.authRepo.BeginPasswordlessLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID)
} }
if err != nil { if err != nil {
errID, errMessage = l.getErrorMessage(r, err) errID, errMessage = l.getErrorMessage(r, err)
@ -65,8 +62,7 @@ func (l *Login) handlePasswordlessVerification(w http.ResponseWriter, r *http.Re
l.renderPasswordlessVerification(w, r, authReq, formData.PasswordLogin, err) l.renderPasswordlessVerification(w, r, authReq, formData.PasswordLogin, err)
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) err = l.authRepo.VerifyPasswordless(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, authReq.AgentID, authReq.InstanceID, credData, domain.BrowserInfoFromRequest(r))
err = l.authRepo.VerifyPasswordless(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID, credData, domain.BrowserInfoFromRequest(r))
if err != nil { if err != nil {
l.renderPasswordlessVerification(w, r, authReq, formData.PasswordLogin, err) l.renderPasswordlessVerification(w, r, authReq, formData.PasswordLogin, err)
return return

View File

@ -5,6 +5,7 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
caos_errs "github.com/caos/zitadel/internal/errors" caos_errs "github.com/caos/zitadel/internal/errors"
@ -61,7 +62,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
l.renderRegister(w, r, authRequest, data, err) l.renderRegister(w, r, authRequest, data, err)
return return
} }
iam, err := l.query.IAMByID(r.Context(), domain.IAMID) iam, err := l.query.IAM(r.Context())
if err != nil { if err != nil {
l.renderRegister(w, r, authRequest, data, err) l.renderRegister(w, r, authRequest, data, err)
return return
@ -94,7 +95,8 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, user.AggregateID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, user.AggregateID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderRegister(w, r, authRequest, data, err) l.renderRegister(w, r, authRequest, data, err)
return return
@ -125,7 +127,7 @@ func (l *Login) renderRegister(w http.ResponseWriter, r *http.Request, authReque
} }
if resourceOwner == "" { if resourceOwner == "" {
iam, err := l.query.IAMByID(r.Context(), domain.IAMID) iam, err := l.query.IAM(r.Context())
if err != nil { if err != nil {
l.renderRegister(w, r, authRequest, formData, err) l.renderRegister(w, r, authRequest, formData, err)
return return

View File

@ -224,8 +224,7 @@ func (l *Login) renderNextStep(w http.ResponseWriter, r *http.Request, authReq *
l.renderInternalError(w, r, nil, caos_errs.ThrowInvalidArgument(nil, "LOGIN-Df3f2", "Errors.AuthRequest.NotFound")) l.renderInternalError(w, r, nil, caos_errs.ThrowInvalidArgument(nil, "LOGIN-Df3f2", "Errors.AuthRequest.NotFound"))
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) authReq, err := l.authRepo.AuthRequestByID(r.Context(), authReq.ID, authReq.AgentID, authReq.InstanceID)
authReq, err := l.authRepo.AuthRequestByID(r.Context(), authReq.ID, userAgentID)
if err != nil { if err != nil {
l.renderInternalError(w, r, authReq, err) l.renderInternalError(w, r, authReq, err)
return return

View File

@ -5,6 +5,7 @@ import (
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/api/authz"
http_mw "github.com/caos/zitadel/internal/api/http/middleware" http_mw "github.com/caos/zitadel/internal/api/http/middleware"
) )
@ -38,7 +39,8 @@ func (l *Login) handleSelectUser(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID) instanceID := authz.GetInstance(r.Context()).ID
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID, instanceID)
if err != nil { if err != nil {
l.renderError(w, r, authSession, err) l.renderError(w, r, authSession, err)
return return

View File

@ -8,30 +8,30 @@ import (
type AuthRequestRepository interface { type AuthRequestRepository interface {
CreateAuthRequest(ctx context.Context, request *domain.AuthRequest) (*domain.AuthRequest, error) CreateAuthRequest(ctx context.Context, request *domain.AuthRequest) (*domain.AuthRequest, error)
AuthRequestByID(ctx context.Context, id, userAgentID string) (*domain.AuthRequest, error) AuthRequestByID(ctx context.Context, id, userAgentID, instanceID string) (*domain.AuthRequest, error)
AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID string) (*domain.AuthRequest, error) AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID, instanceID string) (*domain.AuthRequest, error)
AuthRequestByCode(ctx context.Context, code string) (*domain.AuthRequest, error) AuthRequestByCode(ctx context.Context, code, instanceID string) (*domain.AuthRequest, error)
SaveAuthCode(ctx context.Context, id, code, userAgentID string) error SaveAuthCode(ctx context.Context, id, code, userAgentID, instanceID string) error
DeleteAuthRequest(ctx context.Context, id string) error DeleteAuthRequest(ctx context.Context, id, instanceID string) error
CheckLoginName(ctx context.Context, id, loginName, userAgentID string) error CheckLoginName(ctx context.Context, id, loginName, userAgentID, instanceID string) error
CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo) error CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, user *domain.ExternalUser, info *domain.BrowserInfo) error
SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error SetExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, user *domain.ExternalUser) error
SelectUser(ctx context.Context, id, userID, userAgentID string) error SelectUser(ctx context.Context, id, userID, userAgentID, instanceID string) error
SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string) error SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID, instanceID string) error
VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID, instanceID string, info *domain.BrowserInfo) error
VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID string, info *domain.BrowserInfo) error VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID, instanceID string, info *domain.BrowserInfo) error
BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (*domain.WebAuthNLogin, error) BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (*domain.WebAuthNLogin, error)
VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) error VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) error
BeginPasswordlessSetup(ctx context.Context, userID, resourceOwner string, preferredPlatformType domain.AuthenticatorAttachment) (login *domain.WebAuthNToken, err error) BeginPasswordlessSetup(ctx context.Context, userID, resourceOwner string, preferredPlatformType domain.AuthenticatorAttachment) (login *domain.WebAuthNToken, err error)
VerifyPasswordlessSetup(ctx context.Context, userID, resourceOwner, userAgentID, tokenName string, credentialData []byte) (err error) VerifyPasswordlessSetup(ctx context.Context, userID, resourceOwner, userAgentID, tokenName string, credentialData []byte) (err error)
BeginPasswordlessInitCodeSetup(ctx context.Context, userID, resourceOwner, codeID, verificationCode string, preferredPlatformType domain.AuthenticatorAttachment) (login *domain.WebAuthNToken, err error) BeginPasswordlessInitCodeSetup(ctx context.Context, userID, resourceOwner, codeID, verificationCode string, preferredPlatformType domain.AuthenticatorAttachment) (login *domain.WebAuthNToken, err error)
VerifyPasswordlessInitCodeSetup(ctx context.Context, userID, resourceOwner, userAgentID, tokenName, codeID, verificationCode string, credentialData []byte) (err error) VerifyPasswordlessInitCodeSetup(ctx context.Context, userID, resourceOwner, userAgentID, tokenName, codeID, verificationCode string, credentialData []byte) (err error)
BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (*domain.WebAuthNLogin, error) BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (*domain.WebAuthNLogin, error)
VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) error VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) error
LinkExternalUsers(ctx context.Context, authReqID, userAgentID string, info *domain.BrowserInfo) error LinkExternalUsers(ctx context.Context, authReqID, userAgentID, instanceID string, info *domain.BrowserInfo) error
AutoRegisterExternalUser(ctx context.Context, user *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner string, metadatas []*domain.Metadata, info *domain.BrowserInfo) error AutoRegisterExternalUser(ctx context.Context, user *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner, instanceID string, metadatas []*domain.Metadata, info *domain.BrowserInfo) error
ResetLinkingUsers(ctx context.Context, authReqID, userAgentID string) error ResetLinkingUsers(ctx context.Context, authReqID, userAgentID, instanceID string) error
} }

View File

@ -156,22 +156,22 @@ func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *dom
return request, nil return request, nil
} }
func (repo *AuthRequestRepo) AuthRequestByID(ctx context.Context, id, userAgentID string) (_ *domain.AuthRequest, err error) { func (repo *AuthRequestRepo) AuthRequestByID(ctx context.Context, id, userAgentID, instanceID string) (_ *domain.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
return repo.getAuthRequestNextSteps(ctx, id, userAgentID, false) return repo.getAuthRequestNextSteps(ctx, id, userAgentID, instanceID, false)
} }
func (repo *AuthRequestRepo) AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID string) (_ *domain.AuthRequest, err error) { func (repo *AuthRequestRepo) AuthRequestByIDCheckLoggedIn(ctx context.Context, id, userAgentID, instanceID string) (_ *domain.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
return repo.getAuthRequestNextSteps(ctx, id, userAgentID, true) return repo.getAuthRequestNextSteps(ctx, id, userAgentID, instanceID, true)
} }
func (repo *AuthRequestRepo) SaveAuthCode(ctx context.Context, id, code, userAgentID string) (err error) { func (repo *AuthRequestRepo) SaveAuthCode(ctx context.Context, id, code, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, id, userAgentID) request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -179,10 +179,10 @@ func (repo *AuthRequestRepo) SaveAuthCode(ctx context.Context, id, code, userAge
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) AuthRequestByCode(ctx context.Context, code string) (_ *domain.AuthRequest, err error) { func (repo *AuthRequestRepo) AuthRequestByCode(ctx context.Context, code, instanceID string) (_ *domain.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.AuthRequests.GetAuthRequestByCode(ctx, code) request, err := repo.AuthRequests.GetAuthRequestByCode(ctx, code, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -198,16 +198,16 @@ func (repo *AuthRequestRepo) AuthRequestByCode(ctx context.Context, code string)
return request, nil return request, nil
} }
func (repo *AuthRequestRepo) DeleteAuthRequest(ctx context.Context, id string) (err error) { func (repo *AuthRequestRepo) DeleteAuthRequest(ctx context.Context, id, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
return repo.AuthRequests.DeleteAuthRequest(ctx, id) return repo.AuthRequests.DeleteAuthRequest(ctx, id, instanceID)
} }
func (repo *AuthRequestRepo) CheckLoginName(ctx context.Context, id, loginName, userAgentID string) (err error) { func (repo *AuthRequestRepo) CheckLoginName(ctx context.Context, id, loginName, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, id, userAgentID) request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -218,10 +218,10 @@ func (repo *AuthRequestRepo) CheckLoginName(ctx context.Context, id, loginName,
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string) (err error) { func (repo *AuthRequestRepo) SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -232,10 +232,10 @@ func (repo *AuthRequestRepo) SelectExternalIDP(ctx context.Context, authReqID, i
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, externalUser *domain.ExternalUser, info *domain.BrowserInfo) (err error) { func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, externalUser *domain.ExternalUser, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -257,10 +257,10 @@ func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReq
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, externalUser *domain.ExternalUser) (err error) { func (repo *AuthRequestRepo) SetExternalUserLogin(ctx context.Context, authReqID, userAgentID, instanceID string, externalUser *domain.ExternalUser) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -277,10 +277,10 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, id, userID, userAgentID string) (err error) { func (repo *AuthRequestRepo) SelectUser(ctx context.Context, id, userID, userAgentID, instanceID string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, id, userAgentID) request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -299,10 +299,10 @@ func (repo *AuthRequestRepo) SelectUser(ctx context.Context, id, userID, userAge
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) (err error) { func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID, instanceID string, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, id, userAgentID, userID) request, err := repo.getAuthRequestEnsureUser(ctx, id, userAgentID, userID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -328,31 +328,31 @@ func lockoutPolicyToDomain(policy *query.LockoutPolicy) *domain.LockoutPolicy {
} }
} }
func (repo *AuthRequestRepo) VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID string, info *domain.BrowserInfo) (err error) { func (repo *AuthRequestRepo) VerifyMFAOTP(ctx context.Context, authRequestID, userID, resourceOwner, code, userAgentID, instanceID string, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID) request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil { if err != nil {
return err return err
} }
return repo.Command.HumanCheckMFAOTP(ctx, userID, code, resourceOwner, request.WithCurrentInfo(info)) return repo.Command.HumanCheckMFAOTP(ctx, userID, code, resourceOwner, request.WithCurrentInfo(info))
} }
func (repo *AuthRequestRepo) BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (login *domain.WebAuthNLogin, err error) { func (repo *AuthRequestRepo) BeginMFAU2FLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (login *domain.WebAuthNLogin, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID) request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return repo.Command.HumanBeginU2FLogin(ctx, userID, resourceOwner, request, true) return repo.Command.HumanBeginU2FLogin(ctx, userID, resourceOwner, request, true)
} }
func (repo *AuthRequestRepo) VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) (err error) { func (repo *AuthRequestRepo) VerifyMFAU2F(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID) request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -393,30 +393,30 @@ func (repo *AuthRequestRepo) VerifyPasswordlessInitCodeSetup(ctx context.Context
return err return err
} }
func (repo *AuthRequestRepo) BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string) (login *domain.WebAuthNLogin, err error) { func (repo *AuthRequestRepo) BeginPasswordlessLogin(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string) (login *domain.WebAuthNLogin, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID) request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return repo.Command.HumanBeginPasswordlessLogin(ctx, userID, resourceOwner, request, true) return repo.Command.HumanBeginPasswordlessLogin(ctx, userID, resourceOwner, request, true)
} }
func (repo *AuthRequestRepo) VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID string, credentialData []byte, info *domain.BrowserInfo) (err error) { func (repo *AuthRequestRepo) VerifyPasswordless(ctx context.Context, userID, resourceOwner, authRequestID, userAgentID, instanceID string, credentialData []byte, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID) request, err := repo.getAuthRequestEnsureUser(ctx, authRequestID, userAgentID, userID, instanceID)
if err != nil { if err != nil {
return err return err
} }
return repo.Command.HumanFinishPasswordlessLogin(ctx, userID, resourceOwner, credentialData, request, true) return repo.Command.HumanFinishPasswordlessLogin(ctx, userID, resourceOwner, credentialData, request, true)
} }
func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, userAgentID string, info *domain.BrowserInfo) (err error) { func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, userAgentID, instanceID string, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -432,8 +432,8 @@ func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, u
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) ResetLinkingUsers(ctx context.Context, authReqID, userAgentID string) error { func (repo *AuthRequestRepo) ResetLinkingUsers(ctx context.Context, authReqID, userAgentID, instanceID string) error {
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -442,10 +442,10 @@ func (repo *AuthRequestRepo) ResetLinkingUsers(ctx context.Context, authReqID, u
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, registerUser *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner string, metadatas []*domain.Metadata, info *domain.BrowserInfo) (err error) { func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, registerUser *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner, instanceID string, metadatas []*domain.Metadata, info *domain.BrowserInfo) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID, instanceID)
if err != nil { if err != nil {
return err return err
} }
@ -478,8 +478,8 @@ func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, regis
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) getAuthRequestNextSteps(ctx context.Context, id, userAgentID string, checkLoggedIn bool) (*domain.AuthRequest, error) { func (repo *AuthRequestRepo) getAuthRequestNextSteps(ctx context.Context, id, userAgentID, instanceID string, checkLoggedIn bool) (*domain.AuthRequest, error) {
request, err := repo.getAuthRequest(ctx, id, userAgentID) request, err := repo.getAuthRequest(ctx, id, userAgentID, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -491,8 +491,8 @@ func (repo *AuthRequestRepo) getAuthRequestNextSteps(ctx context.Context, id, us
return request, nil return request, nil
} }
func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authRequestID, userAgentID, userID string) (*domain.AuthRequest, error) { func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authRequestID, userAgentID, userID, instanceID string) (*domain.AuthRequest, error) {
request, err := repo.getAuthRequest(ctx, authRequestID, userAgentID) request, err := repo.getAuthRequest(ctx, authRequestID, userAgentID, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -506,8 +506,8 @@ func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authR
return request, nil return request, nil
} }
func (repo *AuthRequestRepo) getAuthRequest(ctx context.Context, id, userAgentID string) (*domain.AuthRequest, error) { func (repo *AuthRequestRepo) getAuthRequest(ctx context.Context, id, userAgentID, instanceID string) (*domain.AuthRequest, error) {
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id) request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id, instanceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -3,7 +3,7 @@ package handler
import ( import (
"github.com/caos/logging" "github.com/caos/logging"
"github.com/caos/zitadel/internal/eventstore/v1" v1 "github.com/caos/zitadel/internal/eventstore/v1"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models" es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query" "github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler" "github.com/caos/zitadel/internal/eventstore/v1/spooler"
@ -76,6 +76,7 @@ func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) {
case model.ProjectAdded: case model.ProjectAdded:
mapping.OrgID = event.ResourceOwner mapping.OrgID = event.ResourceOwner
mapping.ProjectID = event.AggregateID mapping.ProjectID = event.AggregateID
mapping.InstanceID = event.InstanceID
case model.ProjectRemoved: case model.ProjectRemoved:
err := p.view.DeleteOrgProjectMappingsByProjectID(event.AggregateID) err := p.view.DeleteOrgProjectMappingsByProjectID(event.AggregateID)
if err == nil { if err == nil {
@ -87,6 +88,7 @@ func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) {
mapping.OrgID = projectGrant.GrantedOrgID mapping.OrgID = projectGrant.GrantedOrgID
mapping.ProjectID = event.AggregateID mapping.ProjectID = event.AggregateID
mapping.ProjectGrantID = projectGrant.GrantID mapping.ProjectGrantID = projectGrant.GrantID
mapping.InstanceID = projectGrant.InstanceID
case model.ProjectGrantRemoved: case model.ProjectGrantRemoved:
projectGrant := new(view_model.ProjectGrant) projectGrant := new(view_model.ProjectGrant)
projectGrant.SetData(event) projectGrant.SetData(event)

View File

@ -2,9 +2,10 @@ package handler
import ( import (
"github.com/caos/logging" "github.com/caos/logging"
req_model "github.com/caos/zitadel/internal/auth_request/model" req_model "github.com/caos/zitadel/internal/auth_request/model"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1" v1 "github.com/caos/zitadel/internal/eventstore/v1"
"github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query" "github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler" "github.com/caos/zitadel/internal/eventstore/v1/spooler"
@ -104,6 +105,7 @@ func (u *UserSession) Reduce(event *models.Event) (err error) {
UserAgentID: eventData.UserAgentID, UserAgentID: eventData.UserAgentID,
UserID: event.AggregateID, UserID: event.AggregateID,
State: int32(req_model.UserSessionStateActive), State: int32(req_model.UserSessionStateActive),
InstanceID: event.InstanceID,
} }
} }
return u.updateSession(session, event) return u.updateSession(session, event)

View File

@ -26,38 +26,38 @@ func (c *AuthRequestCache) Health(ctx context.Context) error {
return c.client.PingContext(ctx) return c.client.PingContext(ctx)
} }
func (c *AuthRequestCache) GetAuthRequestByID(_ context.Context, id string) (*domain.AuthRequest, error) { func (c *AuthRequestCache) GetAuthRequestByID(_ context.Context, id, instanceID string) (*domain.AuthRequest, error) {
return c.getAuthRequest("id", id) return c.getAuthRequest("id", id, instanceID)
} }
func (c *AuthRequestCache) GetAuthRequestByCode(_ context.Context, code string) (*domain.AuthRequest, error) { func (c *AuthRequestCache) GetAuthRequestByCode(_ context.Context, code, instanceID string) (*domain.AuthRequest, error) {
return c.getAuthRequest("code", code) return c.getAuthRequest("code", code, instanceID)
} }
func (c *AuthRequestCache) SaveAuthRequest(_ context.Context, request *domain.AuthRequest) error { func (c *AuthRequestCache) SaveAuthRequest(_ context.Context, request *domain.AuthRequest) error {
return c.saveAuthRequest(request, "INSERT INTO auth.auth_requests (id, request, creation_date, change_date, request_type) VALUES($1, $2, $3, $3, $4)", request.CreationDate, request.Request.Type()) return c.saveAuthRequest(request, "INSERT INTO auth.auth_requests (id, request, instance_id, creation_date, change_date, request_type) VALUES($1, $2, $3, $3, $4, $5)", request.CreationDate, request.InstanceID, request.Request.Type())
} }
func (c *AuthRequestCache) UpdateAuthRequest(_ context.Context, request *domain.AuthRequest) error { func (c *AuthRequestCache) UpdateAuthRequest(_ context.Context, request *domain.AuthRequest) error {
if request.ChangeDate.IsZero() { if request.ChangeDate.IsZero() {
request.ChangeDate = time.Now() request.ChangeDate = time.Now()
} }
return c.saveAuthRequest(request, "UPDATE auth.auth_requests SET request = $2, change_date = $3, code = $4 WHERE id = $1", request.ChangeDate, request.Code) return c.saveAuthRequest(request, "UPDATE auth.auth_requests SET request = $2, instance_id = $3 change_date = $4, code = $5 WHERE id = $1", request.ChangeDate, request.InstanceID, request.Code)
} }
func (c *AuthRequestCache) DeleteAuthRequest(_ context.Context, id string) error { func (c *AuthRequestCache) DeleteAuthRequest(_ context.Context, id, instanceID string) error {
_, err := c.client.Exec("DELETE FROM auth.auth_requests WHERE id = $1", id) _, err := c.client.Exec("DELETE FROM auth.auth_requests WHERE instance = $1 and id = $2", instanceID, id)
if err != nil { if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-dsHw3", "unable to delete auth request") return caos_errs.ThrowInternal(err, "CACHE-dsHw3", "unable to delete auth request")
} }
return nil return nil
} }
func (c *AuthRequestCache) getAuthRequest(key, value string) (*domain.AuthRequest, error) { func (c *AuthRequestCache) getAuthRequest(key, value, instanceID string) (*domain.AuthRequest, error) {
var b []byte var b []byte
var requestType domain.AuthRequestType var requestType domain.AuthRequestType
query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE %s = $1", key) query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE instance = $1 and %s = $2", key)
err := c.client.QueryRow(query, value).Scan(&b, &requestType) err := c.client.QueryRow(query, instanceID, value).Scan(&b, &requestType)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, caos_errs.ThrowNotFound(err, "CACHE-d24aD", "Errors.AuthRequest.NotFound") return nil, caos_errs.ThrowNotFound(err, "CACHE-d24aD", "Errors.AuthRequest.NotFound")
@ -74,7 +74,7 @@ func (c *AuthRequestCache) getAuthRequest(key, value string) (*domain.AuthReques
return request, nil return request, nil
} }
func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query string, date time.Time, param interface{}) error { func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query string, date time.Time, instanceID string, param interface{}) error {
b, err := json.Marshal(request) b, err := json.Marshal(request)
if err != nil { if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-os0GH", "Errors.Internal") return caos_errs.ThrowInternal(err, "CACHE-os0GH", "Errors.Internal")
@ -83,7 +83,7 @@ func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query st
if err != nil { if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-su3GK", "Errors.Internal") return caos_errs.ThrowInternal(err, "CACHE-su3GK", "Errors.Internal")
} }
_, err = stmt.Exec(request.ID, b, date, param) _, err = stmt.Exec(request.ID, b, date, instanceID, param)
if err != nil { if err != nil {
return caos_errs.ThrowInternal(err, "CACHE-sj8iS", "Errors.Internal") return caos_errs.ThrowInternal(err, "CACHE-sj8iS", "Errors.Internal")
} }

View File

@ -6,79 +6,80 @@ package mock
import ( import (
context "context" context "context"
reflect "reflect"
domain "github.com/caos/zitadel/internal/domain" domain "github.com/caos/zitadel/internal/domain"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
) )
// MockAuthRequestCache is a mock of AuthRequestCache interface // MockAuthRequestCache is a mock of AuthRequestCache interface.
type MockAuthRequestCache struct { type MockAuthRequestCache struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockAuthRequestCacheMockRecorder recorder *MockAuthRequestCacheMockRecorder
} }
// MockAuthRequestCacheMockRecorder is the mock recorder for MockAuthRequestCache // MockAuthRequestCacheMockRecorder is the mock recorder for MockAuthRequestCache.
type MockAuthRequestCacheMockRecorder struct { type MockAuthRequestCacheMockRecorder struct {
mock *MockAuthRequestCache mock *MockAuthRequestCache
} }
// NewMockAuthRequestCache creates a new mock instance // NewMockAuthRequestCache creates a new mock instance.
func NewMockAuthRequestCache(ctrl *gomock.Controller) *MockAuthRequestCache { func NewMockAuthRequestCache(ctrl *gomock.Controller) *MockAuthRequestCache {
mock := &MockAuthRequestCache{ctrl: ctrl} mock := &MockAuthRequestCache{ctrl: ctrl}
mock.recorder = &MockAuthRequestCacheMockRecorder{mock} mock.recorder = &MockAuthRequestCacheMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder { func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder {
return m.recorder return m.recorder
} }
// DeleteAuthRequest mocks base method // DeleteAuthRequest mocks base method.
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1 string) error { func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1) ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest // DeleteAuthRequest indicates an expected call of DeleteAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1, arg2)
} }
// GetAuthRequestByCode mocks base method // GetAuthRequestByCode mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) { func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1) ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1, arg2)
ret0, _ := ret[0].(*domain.AuthRequest) ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode // GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1, arg2)
} }
// GetAuthRequestByID mocks base method // GetAuthRequestByID mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) { func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1) ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1, arg2)
ret0, _ := ret[0].(*domain.AuthRequest) ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetAuthRequestByID indicates an expected call of GetAuthRequestByID // GetAuthRequestByID indicates an expected call of GetAuthRequestByID.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1, arg2)
} }
// Health mocks base method // Health mocks base method.
func (m *MockAuthRequestCache) Health(arg0 context.Context) error { func (m *MockAuthRequestCache) Health(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0) ret := m.ctrl.Call(m, "Health", arg0)
@ -86,13 +87,13 @@ func (m *MockAuthRequestCache) Health(arg0 context.Context) error {
return ret0 return ret0
} }
// Health indicates an expected call of Health // Health indicates an expected call of Health.
func (mr *MockAuthRequestCacheMockRecorder) Health(arg0 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) Health(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockAuthRequestCache)(nil).Health), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockAuthRequestCache)(nil).Health), arg0)
} }
// SaveAuthRequest mocks base method // SaveAuthRequest mocks base method.
func (m *MockAuthRequestCache) SaveAuthRequest(arg0 context.Context, arg1 *domain.AuthRequest) error { func (m *MockAuthRequestCache) SaveAuthRequest(arg0 context.Context, arg1 *domain.AuthRequest) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveAuthRequest", arg0, arg1) ret := m.ctrl.Call(m, "SaveAuthRequest", arg0, arg1)
@ -100,13 +101,13 @@ func (m *MockAuthRequestCache) SaveAuthRequest(arg0 context.Context, arg1 *domai
return ret0 return ret0
} }
// SaveAuthRequest indicates an expected call of SaveAuthRequest // SaveAuthRequest indicates an expected call of SaveAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) SaveAuthRequest(arg0, arg1 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) SaveAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).SaveAuthRequest), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).SaveAuthRequest), arg0, arg1)
} }
// UpdateAuthRequest mocks base method // UpdateAuthRequest mocks base method.
func (m *MockAuthRequestCache) UpdateAuthRequest(arg0 context.Context, arg1 *domain.AuthRequest) error { func (m *MockAuthRequestCache) UpdateAuthRequest(arg0 context.Context, arg1 *domain.AuthRequest) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAuthRequest", arg0, arg1) ret := m.ctrl.Call(m, "UpdateAuthRequest", arg0, arg1)
@ -114,7 +115,7 @@ func (m *MockAuthRequestCache) UpdateAuthRequest(arg0 context.Context, arg1 *dom
return ret0 return ret0
} }
// UpdateAuthRequest indicates an expected call of UpdateAuthRequest // UpdateAuthRequest indicates an expected call of UpdateAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) UpdateAuthRequest(arg0, arg1 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) UpdateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).UpdateAuthRequest), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).UpdateAuthRequest), arg0, arg1)

View File

@ -2,15 +2,16 @@ package repository
import ( import (
"context" "context"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
) )
type AuthRequestCache interface { type AuthRequestCache interface {
Health(ctx context.Context) error Health(ctx context.Context) error
GetAuthRequestByID(ctx context.Context, id string) (*domain.AuthRequest, error) GetAuthRequestByID(ctx context.Context, id, instanceID string) (*domain.AuthRequest, error)
GetAuthRequestByCode(ctx context.Context, code string) (*domain.AuthRequest, error) GetAuthRequestByCode(ctx context.Context, code, instanceID string) (*domain.AuthRequest, error)
SaveAuthRequest(ctx context.Context, request *domain.AuthRequest) error SaveAuthRequest(ctx context.Context, request *domain.AuthRequest) error
UpdateAuthRequest(ctx context.Context, request *domain.AuthRequest) error UpdateAuthRequest(ctx context.Context, request *domain.AuthRequest) error
DeleteAuthRequest(ctx context.Context, id string) error DeleteAuthRequest(ctx context.Context, id, instanceID string) error
} }

View File

@ -236,7 +236,7 @@ func (repo *TokenVerifierRepo) VerifierClientID(ctx context.Context, appName str
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
iam, err := repo.Query.IAMByID(ctx, domain.IAMID) iam, err := repo.Query.IAM(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }

View File

@ -28,6 +28,7 @@ func (repo *UserMembershipRepo) SearchMyMemberships(ctx context.Context) ([]*aut
func (repo *UserMembershipRepo) searchUserMemberships(ctx context.Context) ([]*user_view_model.UserMembershipView, error) { func (repo *UserMembershipRepo) searchUserMemberships(ctx context.Context) ([]*user_view_model.UserMembershipView, error) {
ctxData := authz.GetCtxData(ctx) ctxData := authz.GetCtxData(ctx)
instance := authz.GetInstance(ctx)
orgMemberships, orgCount, err := repo.View.SearchUserMemberships(&user_model.UserMembershipSearchRequest{ orgMemberships, orgCount, err := repo.View.SearchUserMemberships(&user_model.UserMembershipSearchRequest{
Queries: []*user_model.UserMembershipSearchQuery{ Queries: []*user_model.UserMembershipSearchQuery{
{ {
@ -40,6 +41,11 @@ func (repo *UserMembershipRepo) searchUserMemberships(ctx context.Context) ([]*u
Method: domain.SearchMethodEquals, Method: domain.SearchMethodEquals,
Value: ctxData.OrgID, Value: ctxData.OrgID,
}, },
{
Key: user_model.UserMembershipSearchKeyInstanceID,
Method: domain.SearchMethodEquals,
Value: instance.ID,
},
}, },
}) })
if err != nil { if err != nil {
@ -57,6 +63,11 @@ func (repo *UserMembershipRepo) searchUserMemberships(ctx context.Context) ([]*u
Method: domain.SearchMethodEquals, Method: domain.SearchMethodEquals,
Value: domain.IAMID, Value: domain.IAMID,
}, },
{
Key: user_model.UserMembershipSearchKeyInstanceID,
Method: domain.SearchMethodEquals,
Value: instance.ID,
},
}, },
}) })
if err != nil { if err != nil {

View File

@ -30,8 +30,6 @@ func (h *handler) Eventstore() v1.Eventstore {
func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es v1.Eventstore, systemDefaults sd.SystemDefaults) []query.Handler { func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es v1.Eventstore, systemDefaults sd.SystemDefaults) []query.Handler {
return []query.Handler{ return []query.Handler{
newUserGrant(
handler{view, bulkLimit, configs.cycleDuration("UserGrants"), errorCount, es}),
newUserMembership( newUserMembership(
handler{view, bulkLimit, configs.cycleDuration("UserMemberships"), errorCount, es}), handler{view, bulkLimit, configs.cycleDuration("UserMemberships"), errorCount, es}),
} }

View File

@ -1,313 +0,0 @@
package handler
import (
"context"
"strings"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
es_sdk "github.com/caos/zitadel/internal/eventstore/v1/sdk"
iam_model "github.com/caos/zitadel/internal/iam/model"
iam_view "github.com/caos/zitadel/internal/iam/repository/view"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
caos_errs "github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler"
iam_es_model "github.com/caos/zitadel/internal/iam/repository/eventsourcing/model"
org_es_model "github.com/caos/zitadel/internal/org/repository/eventsourcing/model"
proj_es_model "github.com/caos/zitadel/internal/project/repository/eventsourcing/model"
view_model "github.com/caos/zitadel/internal/usergrant/repository/view/model"
)
const (
userGrantTable = "authz.user_grants"
)
type UserGrant struct {
handler
iamProjectID string
subscription *v1.Subscription
}
func newUserGrant(
handler handler,
) *UserGrant {
h := &UserGrant{
handler: handler,
}
h.subscribe()
return h
}
func (k *UserGrant) subscribe() {
k.subscription = k.es.Subscribe(k.AggregateTypes()...)
go func() {
for event := range k.subscription.Events {
query.ReduceEvent(k, event)
}
}()
}
func (u *UserGrant) ViewModel() string {
return userGrantTable
}
func (u *UserGrant) Subscription() *v1.Subscription {
return u.subscription
}
func (_ *UserGrant) AggregateTypes() []es_models.AggregateType {
return []es_models.AggregateType{iam_es_model.IAMAggregate, org_es_model.OrgAggregate, proj_es_model.ProjectAggregate}
}
func (u *UserGrant) CurrentSequence() (uint64, error) {
sequence, err := u.view.GetLatestUserGrantSequence()
if err != nil {
return 0, err
}
return sequence.CurrentSequence, nil
}
func (u *UserGrant) EventQuery() (*es_models.SearchQuery, error) {
if u.iamProjectID == "" {
err := u.setIamProjectID()
if err != nil {
return nil, err
}
}
sequence, err := u.view.GetLatestUserGrantSequence()
if err != nil {
return nil, err
}
return es_models.NewSearchQuery().
AggregateTypeFilter(iam_es_model.IAMAggregate, org_es_model.OrgAggregate, proj_es_model.ProjectAggregate).
LatestSequenceFilter(sequence.CurrentSequence), nil
}
func (u *UserGrant) Reduce(event *es_models.Event) (err error) {
switch event.AggregateType {
case proj_es_model.ProjectAggregate:
err = u.processProject(event)
case iam_es_model.IAMAggregate:
err = u.processIAMMember(event, "IAM", false)
case org_es_model.OrgAggregate:
return u.processOrg(event)
}
return err
}
func (u *UserGrant) processProject(event *es_models.Event) (err error) {
switch event.Type {
case proj_es_model.ProjectMemberAdded, proj_es_model.ProjectMemberChanged,
proj_es_model.ProjectMemberRemoved, proj_es_model.ProjectMemberCascadeRemoved:
member := new(proj_es_model.ProjectMember)
err := member.SetData(event)
if err != nil {
return err
}
return u.processMember(event, "PROJECT", event.AggregateID, member.UserID, member.Roles)
case proj_es_model.ProjectGrantMemberAdded, proj_es_model.ProjectGrantMemberChanged,
proj_es_model.ProjectGrantMemberRemoved,
proj_es_model.ProjectGrantMemberCascadeRemoved:
member := new(proj_es_model.ProjectGrantMember)
err := member.SetData(event)
if err != nil {
return err
}
return u.processMember(event, "PROJECT_GRANT", member.GrantID, member.UserID, member.Roles)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func (u *UserGrant) processOrg(event *es_models.Event) (err error) {
switch event.Type {
case org_es_model.OrgMemberAdded, org_es_model.OrgMemberChanged,
org_es_model.OrgMemberRemoved, org_es_model.OrgMemberCascadeRemoved:
member := new(org_es_model.OrgMember)
err := member.SetData(event)
if err != nil {
return err
}
return u.processMember(event, "ORG", "", member.UserID, member.Roles)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func (u *UserGrant) processIAMMember(event *es_models.Event, rolePrefix string, suffix bool) error {
member := new(iam_es_model.IAMMember)
switch event.Type {
case iam_es_model.IAMMemberAdded, iam_es_model.IAMMemberChanged:
member.SetData(event)
grant, err := u.view.UserGrantByIDs(domain.IAMID, u.iamProjectID, member.UserID)
if err != nil && !errors.IsNotFound(err) {
return err
}
if errors.IsNotFound(err) {
grant = &view_model.UserGrantView{
ID: u.iamProjectID + member.UserID,
ResourceOwner: domain.IAMID,
OrgName: domain.IAMID,
ProjectID: u.iamProjectID,
UserID: member.UserID,
RoleKeys: member.Roles,
CreationDate: event.CreationDate,
}
if suffix {
grant.RoleKeys = suffixRoles(event.AggregateID, grant.RoleKeys)
}
} else {
newRoles := member.Roles
if grant.RoleKeys != nil {
grant.RoleKeys = mergeExistingRoles(rolePrefix, "", grant.RoleKeys, newRoles)
} else {
grant.RoleKeys = newRoles
}
}
grant.Sequence = event.Sequence
grant.ChangeDate = event.CreationDate
return u.view.PutUserGrant(grant, event)
case iam_es_model.IAMMemberRemoved,
iam_es_model.IAMMemberCascadeRemoved:
member.SetData(event)
grant, err := u.view.UserGrantByIDs(domain.IAMID, u.iamProjectID, member.UserID)
if err != nil {
return err
}
return u.view.DeleteUserGrant(grant.ID, event)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func (u *UserGrant) processMember(event *es_models.Event, rolePrefix, roleSuffix string, userID string, roleKeys []string) error {
switch event.Type {
case org_es_model.OrgMemberAdded, proj_es_model.ProjectMemberAdded, proj_es_model.ProjectGrantMemberAdded,
org_es_model.OrgMemberChanged, proj_es_model.ProjectMemberChanged, proj_es_model.ProjectGrantMemberChanged:
grant, err := u.view.UserGrantByIDs(event.ResourceOwner, u.iamProjectID, userID)
if err != nil && !errors.IsNotFound(err) {
return err
}
if roleSuffix != "" {
roleKeys = suffixRoles(roleSuffix, roleKeys)
}
if errors.IsNotFound(err) {
grant = &view_model.UserGrantView{
ID: u.iamProjectID + event.ResourceOwner + userID,
ResourceOwner: event.ResourceOwner,
ProjectID: u.iamProjectID,
UserID: userID,
RoleKeys: roleKeys,
CreationDate: event.CreationDate,
}
} else {
newRoles := roleKeys
if grant.RoleKeys != nil {
grant.RoleKeys = mergeExistingRoles(rolePrefix, roleSuffix, grant.RoleKeys, newRoles)
} else {
grant.RoleKeys = newRoles
}
}
grant.Sequence = event.Sequence
grant.ChangeDate = event.CreationDate
return u.view.PutUserGrant(grant, event)
case org_es_model.OrgMemberRemoved,
org_es_model.OrgMemberCascadeRemoved,
proj_es_model.ProjectMemberRemoved,
proj_es_model.ProjectMemberCascadeRemoved,
proj_es_model.ProjectGrantMemberRemoved,
proj_es_model.ProjectGrantMemberCascadeRemoved:
grant, err := u.view.UserGrantByIDs(event.ResourceOwner, u.iamProjectID, userID)
if err != nil && !errors.IsNotFound(err) {
return err
}
if errors.IsNotFound(err) {
return u.view.ProcessedUserGrantSequence(event)
}
if roleSuffix != "" {
roleKeys = suffixRoles(roleSuffix, roleKeys)
}
if grant.RoleKeys == nil {
return u.view.ProcessedUserGrantSequence(event)
}
grant.RoleKeys = mergeExistingRoles(rolePrefix, roleSuffix, grant.RoleKeys, nil)
return u.view.PutUserGrant(grant, event)
default:
return u.view.ProcessedUserGrantSequence(event)
}
}
func suffixRoles(suffix string, roles []string) []string {
suffixedRoles := make([]string, len(roles))
for i := 0; i < len(roles); i++ {
suffixedRoles[i] = roles[i] + ":" + suffix
}
return suffixedRoles
}
func mergeExistingRoles(rolePrefix, suffix string, existingRoles, newRoles []string) []string {
mergedRoles := make([]string, 0)
for _, existingRole := range existingRoles {
if !strings.HasPrefix(existingRole, rolePrefix) {
mergedRoles = append(mergedRoles, existingRole)
continue
}
if suffix != "" && !strings.HasSuffix(existingRole, suffix) {
mergedRoles = append(mergedRoles, existingRole)
}
}
return append(mergedRoles, newRoles...)
}
func (u *UserGrant) setIamProjectID() error {
if u.iamProjectID != "" {
return nil
}
iam, err := u.getIAMByID(context.Background())
if err != nil {
return err
}
if iam.SetUpDone < domain.StepCount-1 {
return caos_errs.ThrowPreconditionFailed(nil, "HANDL-s5DTs", "Setup not done")
}
u.iamProjectID = iam.IAMProjectID
return nil
}
func (u *UserGrant) OnError(event *es_models.Event, err error) error {
logging.LogWithFields("SPOOL-VcVoJ", "id", event.AggregateID).WithError(err).Warn("something went wrong in user grant handler")
return spooler.HandleError(event, err, u.view.GetLatestUserGrantFailedEvent, u.view.ProcessedUserGrantFailedEvent, u.view.ProcessedUserGrantSequence, u.errorCountUntilSkip)
}
func (u *UserGrant) OnSuccess() error {
return spooler.HandleSuccess(u.view.UpdateUserGrantSpoolerRunTimestamp)
}
func (u *UserGrant) getIAMByID(ctx context.Context) (*iam_model.IAM, error) {
query, err := iam_view.IAMByIDQuery(domain.IAMID, 0)
if err != nil {
return nil, err
}
iam := &iam_es_model.IAM{
ObjectRoot: es_models.ObjectRoot{
AggregateID: domain.IAMID,
},
}
err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, iam.AppendEvents, query)
if err != nil && errors.IsNotFound(err) && iam.Sequence == 0 {
return nil, err
}
return iam_es_model.IAMToModel(iam), nil
}

View File

@ -1,70 +0,0 @@
package view
import (
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
grant_model "github.com/caos/zitadel/internal/usergrant/model"
"github.com/caos/zitadel/internal/usergrant/repository/view"
"github.com/caos/zitadel/internal/usergrant/repository/view/model"
"github.com/caos/zitadel/internal/view/repository"
)
const (
userGrantTable = "authz.user_grants"
)
func (v *View) UserGrantByID(grantID string) (*model.UserGrantView, error) {
return view.UserGrantByID(v.Db, userGrantTable, grantID)
}
func (v *View) UserGrantByIDs(resourceOwnerID, projectID, userID string) (*model.UserGrantView, error) {
return view.UserGrantByIDs(v.Db, userGrantTable, resourceOwnerID, projectID, userID)
}
func (v *View) UserGrantsByUserID(userID string) ([]*model.UserGrantView, error) {
return view.UserGrantsByUserID(v.Db, userGrantTable, userID)
}
func (v *View) UserGrantsByProjectID(projectID string) ([]*model.UserGrantView, error) {
return view.UserGrantsByProjectID(v.Db, userGrantTable, projectID)
}
func (v *View) SearchUserGrants(request *grant_model.UserGrantSearchRequest) ([]*model.UserGrantView, uint64, error) {
return view.SearchUserGrants(v.Db, userGrantTable, request)
}
func (v *View) PutUserGrant(grant *model.UserGrantView, event *models.Event) error {
err := view.PutUserGrant(v.Db, userGrantTable, grant)
if err != nil {
return err
}
return v.ProcessedUserGrantSequence(event)
}
func (v *View) DeleteUserGrant(grantID string, event *models.Event) error {
err := view.DeleteUserGrant(v.Db, userGrantTable, grantID)
if err != nil && !errors.IsNotFound(err) {
return err
}
return v.ProcessedUserGrantSequence(event)
}
func (v *View) GetLatestUserGrantSequence() (*repository.CurrentSequence, error) {
return v.latestSequence(userGrantTable)
}
func (v *View) ProcessedUserGrantSequence(event *models.Event) error {
return v.saveCurrentSequence(userGrantTable, event)
}
func (v *View) UpdateUserGrantSpoolerRunTimestamp() error {
return v.updateSpoolerRunSequence(userGrantTable)
}
func (v *View) GetLatestUserGrantFailedEvent(sequence uint64) (*repository.FailedEvent, error) {
return v.latestFailedEvent(userGrantTable, sequence)
}
func (v *View) ProcessedUserGrantFailedEvent(failedEvent *repository.FailedEvent) error {
return v.saveFailedEvent(failedEvent)
}

View File

@ -23,6 +23,7 @@ type AuthRequest struct {
UiLocales []string UiLocales []string
LoginHint string LoginHint string
MaxAuthAge *time.Duration MaxAuthAge *time.Duration
InstanceID string
Request Request Request Request
levelOfAssurance LevelOfAssurance levelOfAssurance LevelOfAssurance

View File

@ -21,7 +21,7 @@ func NewAggregate(
ID: id, ID: id,
Type: typ, Type: typ,
ResourceOwner: authz.GetCtxData(ctx).OrgID, ResourceOwner: authz.GetCtxData(ctx).OrgID,
Tenant: authz.GetCtxData(ctx).TenantID, InstanceID: authz.GetInstance(ctx).ID,
Version: version, Version: version,
} }
@ -50,7 +50,7 @@ func AggregateFromWriteModel(
ID: wm.AggregateID, ID: wm.AggregateID,
Type: typ, Type: typ,
ResourceOwner: wm.ResourceOwner, ResourceOwner: wm.ResourceOwner,
Tenant: wm.Tenant, InstanceID: wm.InstanceID,
Version: version, Version: version,
} }
} }
@ -63,8 +63,8 @@ type Aggregate struct {
Type AggregateType `json:"-"` Type AggregateType `json:"-"`
//ResourceOwner is the org this aggregates belongs to //ResourceOwner is the org this aggregates belongs to
ResourceOwner string `json:"-"` ResourceOwner string `json:"-"`
//Tenant is the system this aggregate belongs to //InstanceID is the instance this aggregate belongs to
Tenant string `json:"-"` InstanceID string `json:"-"`
//Version is the semver this aggregate represents //Version is the semver this aggregate represents
Version Version `json:"-"` Version Version `json:"-"`
} }

View File

@ -79,7 +79,7 @@ func BaseEventFromRepo(event *repository.Event) *BaseEvent {
ID: event.AggregateID, ID: event.AggregateID,
Type: AggregateType(event.AggregateType), Type: AggregateType(event.AggregateType),
ResourceOwner: event.ResourceOwner.String, ResourceOwner: event.ResourceOwner.String,
Tenant: event.Tenant.String, InstanceID: event.InstanceID.String,
Version: Version(event.Version), Version: Version(event.Version),
}, },
EventType: EventType(event.Type), EventType: EventType(event.Type),

View File

@ -41,7 +41,7 @@ func (es *Eventstore) Health(ctx context.Context) error {
//Push pushes the events in a single transaction //Push pushes the events in a single transaction
// an event needs at least an aggregate // an event needs at least an aggregate
func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error) { func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error) {
events, constraints, err := commandsToRepository(authz.GetCtxData(ctx).TenantID, cmds) events, constraints, err := commandsToRepository(authz.GetInstance(ctx).ID, cmds)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -59,7 +59,7 @@ func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error
return eventReaders, nil return eventReaders, nil
} }
func commandsToRepository(tenantID string, cmds []Command) (events []*repository.Event, constraints []*repository.UniqueConstraint, err error) { func commandsToRepository(instanceID string, cmds []Command) (events []*repository.Event, constraints []*repository.UniqueConstraint, err error) {
events = make([]*repository.Event, len(cmds)) events = make([]*repository.Event, len(cmds))
for i, cmd := range cmds { for i, cmd := range cmds {
data, err := EventData(cmd) data, err := EventData(cmd)
@ -82,7 +82,7 @@ func commandsToRepository(tenantID string, cmds []Command) (events []*repository
AggregateID: cmd.Aggregate().ID, AggregateID: cmd.Aggregate().ID,
AggregateType: repository.AggregateType(cmd.Aggregate().Type), AggregateType: repository.AggregateType(cmd.Aggregate().Type),
ResourceOwner: sql.NullString{String: cmd.Aggregate().ResourceOwner, Valid: cmd.Aggregate().ResourceOwner != ""}, ResourceOwner: sql.NullString{String: cmd.Aggregate().ResourceOwner, Valid: cmd.Aggregate().ResourceOwner != ""},
Tenant: sql.NullString{String: tenantID, Valid: tenantID != ""}, InstanceID: sql.NullString{String: instanceID, Valid: instanceID != ""},
EditorService: cmd.EditorService(), EditorService: cmd.EditorService(),
EditorUser: cmd.EditorUser(), EditorUser: cmd.EditorUser(),
Type: repository.EventType(cmd.Type()), Type: repository.EventType(cmd.Type()),
@ -113,7 +113,7 @@ func uniqueConstraintsToRepository(constraints []*EventUniqueConstraint) (unique
//Filter filters the stored events based on the searchQuery //Filter filters the stored events based on the searchQuery
// and maps the events to the defined event structs // and maps the events to the defined event structs
func (es *Eventstore) Filter(ctx context.Context, queryFactory *SearchQueryBuilder) ([]Event, error) { func (es *Eventstore) Filter(ctx context.Context, queryFactory *SearchQueryBuilder) ([]Event, error) {
query, err := queryFactory.build(authz.GetCtxData(ctx).TenantID) query, err := queryFactory.build(authz.GetInstance(ctx).ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -170,7 +170,7 @@ func (es *Eventstore) FilterToReducer(ctx context.Context, searchQuery *SearchQu
//LatestSequence filters the latest sequence for the given search query //LatestSequence filters the latest sequence for the given search query
func (es *Eventstore) LatestSequence(ctx context.Context, queryFactory *SearchQueryBuilder) (uint64, error) { func (es *Eventstore) LatestSequence(ctx context.Context, queryFactory *SearchQueryBuilder) (uint64, error) {
query, err := queryFactory.build(authz.GetCtxData(ctx).TenantID) query, err := queryFactory.build(authz.GetInstance(ctx).ID)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -29,7 +29,7 @@ func newTestEvent(id, description string, data func() interface{}, checkPrevious
data: data, data: data,
shouldCheckPrevious: checkPrevious, shouldCheckPrevious: checkPrevious,
BaseEvent: *NewBaseEventForPush( BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"), service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate(authz.NewMockContext("zitadel", "caos", "adlerhurst"), id, "test.aggregate", "v1"), NewAggregate(authz.NewMockContext("zitadel", "caos", "adlerhurst"), id, "test.aggregate", "v1"),
"test.event", "test.event",
), ),
@ -344,8 +344,8 @@ func Test_eventData(t *testing.T) {
func TestEventstore_aggregatesToEvents(t *testing.T) { func TestEventstore_aggregatesToEvents(t *testing.T) {
type args struct { type args struct {
tenantID string instanceID string
events []Command events []Command
} }
type res struct { type res struct {
wantErr bool wantErr bool
@ -359,7 +359,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
{ {
name: "one aggregate one event", name: "one aggregate one event",
args: args{ args: args{
tenantID: "tenant", instanceID: "instanceID",
events: []Command{ events: []Command{
newTestEvent( newTestEvent(
"1", "1",
@ -380,7 +380,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "tenant", Valid: true}, InstanceID: sql.NullString{String: "instanceID", Valid: true},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -390,7 +390,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
{ {
name: "one aggregate multiple events", name: "one aggregate multiple events",
args: args{ args: args{
tenantID: "tenant", instanceID: "instanceID",
events: []Command{ events: []Command{
newTestEvent( newTestEvent(
"1", "1",
@ -418,7 +418,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "tenant", Valid: true}, InstanceID: sql.NullString{String: "instanceID", Valid: true},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -429,7 +429,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "tenant", Valid: true}, InstanceID: sql.NullString{String: "instanceID", Valid: true},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -439,7 +439,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
{ {
name: "invalid data", name: "invalid data",
args: args{ args: args{
tenantID: "tenant", instanceID: "instanceID",
events: []Command{ events: []Command{
newTestEvent( newTestEvent(
"1", "1",
@ -460,7 +460,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{ events: []Command{
&testEvent{ &testEvent{
BaseEvent: *NewBaseEventForPush( BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"), service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate( NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"), authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"", "",
@ -485,7 +485,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{ events: []Command{
&testEvent{ &testEvent{
BaseEvent: *NewBaseEventForPush( BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"), service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate( NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"), authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"id", "id",
@ -510,7 +510,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{ events: []Command{
&testEvent{ &testEvent{
BaseEvent: *NewBaseEventForPush( BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"), service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate( NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"), authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"id", "id",
@ -535,7 +535,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{ events: []Command{
&testEvent{ &testEvent{
BaseEvent: *NewBaseEventForPush( BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "resourceOwner", "editorUser"), "editorService"), service.WithService(authz.NewMockContext("instanceID", "resourceOwner", "editorUser"), "editorService"),
NewAggregate( NewAggregate(
authz.NewMockContext("zitadel", "caos", "adlerhurst"), authz.NewMockContext("zitadel", "caos", "adlerhurst"),
"id", "id",
@ -560,7 +560,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
events: []Command{ events: []Command{
&testEvent{ &testEvent{
BaseEvent: *NewBaseEventForPush( BaseEvent: *NewBaseEventForPush(
service.WithService(authz.NewMockContext("tenant", "", "editorUser"), "editorService"), service.WithService(authz.NewMockContext("instanceID", "", "editorUser"), "editorService"),
NewAggregate( NewAggregate(
authz.NewMockContext("zitadel", "", "adlerhurst"), authz.NewMockContext("zitadel", "", "adlerhurst"),
"id", "id",
@ -585,7 +585,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "", Valid: false}, ResourceOwner: sql.NullString{String: "", Valid: false},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -630,7 +630,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -641,7 +641,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -654,7 +654,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -665,7 +665,7 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
events, _, err := commandsToRepository(tt.args.tenantID, tt.args.events) events, _, err := commandsToRepository(tt.args.instanceID, tt.args.events)
if (err != nil) != tt.res.wantErr { if (err != nil) != tt.res.wantErr {
t.Errorf("Eventstore.aggregatesToEvents() error = %v, wantErr %v", err, tt.res.wantErr) t.Errorf("Eventstore.aggregatesToEvents() error = %v, wantErr %v", err, tt.res.wantErr)
return return
@ -772,7 +772,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -816,7 +816,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -827,7 +827,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -882,7 +882,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -893,7 +893,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },
@ -906,7 +906,7 @@ func TestEventstore_Push(t *testing.T) {
EditorService: "editorService", EditorService: "editorService",
EditorUser: "editorUser", EditorUser: "editorUser",
ResourceOwner: sql.NullString{String: "caos", Valid: true}, ResourceOwner: sql.NullString{String: "caos", Valid: true},
Tenant: sql.NullString{String: "zitadel"}, InstanceID: sql.NullString{String: "zitadel"},
Type: "test.event", Type: "test.event",
Version: "v1", Version: "v1",
}, },

View File

@ -8,15 +8,16 @@ import (
"time" "time"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/caos/zitadel/internal/eventstore" "github.com/caos/zitadel/internal/eventstore"
) )
type mockExpectation func(sqlmock.Sqlmock) type mockExpectation func(sqlmock.Sqlmock)
func expectFailureCount(tableName string, projectionName string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) { func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`). m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\ AND instance_id = \$3\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`).
WithArgs(projectionName, failedSeq). WithArgs(projectionName, failedSeq, instanceID).
WillReturnRows( WillReturnRows(
sqlmock.NewRows([]string{"failure_count"}). sqlmock.NewRows([]string{"failure_count"}).
AddRow(failureCount), AddRow(failureCount),
@ -24,10 +25,10 @@ func expectFailureCount(tableName string, projectionName string, failedSeq, fail
} }
} }
func expectUpdateFailureCount(tableName string, projectionName string, seq, failureCount uint64) func(sqlmock.Sqlmock) { func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) {
m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error\) VALUES \(\$1, \$2, \$3, \$4\)`). m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\)`).
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID).WillReturnResult(sqlmock.NewResult(1, 1))
} }
} }

View File

@ -4,15 +4,16 @@ import (
"database/sql" "database/sql"
"github.com/caos/logging" "github.com/caos/logging"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/handler" "github.com/caos/zitadel/internal/eventstore/handler"
) )
const ( const (
setFailureCountStmtFormat = "UPSERT INTO %s" + setFailureCountStmtFormat = "UPSERT INTO %s" +
" (projection_name, failed_sequence, failure_count, error)" + " (projection_name, failed_sequence, failure_count, error, instance_id)" +
" VALUES ($1, $2, $3, $4)" " VALUES ($1, $2, $3, $4, $5)"
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2)" + failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
" SELECT IF(" + " SELECT IF(" +
"EXISTS(SELECT failure_count FROM failures)," + "EXISTS(SELECT failure_count FROM failures)," +
" (SELECT failure_count FROM failures)," + " (SELECT failure_count FROM failures)," +
@ -21,31 +22,31 @@ const (
) )
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) { func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {
failureCount, err := h.failureCount(tx, stmt.Sequence) failureCount, err := h.failureCount(tx, stmt.Sequence, stmt.InstanceID)
if err != nil { if err != nil {
logging.WithFields("projection", h.ProjectionName, "seq", stmt.Sequence).WithError(err).Warn("unable to get failure count") logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).WithError(err).Warn("unable to get failure count")
return false return false
} }
failureCount += 1 failureCount += 1
err = h.setFailureCount(tx, stmt.Sequence, failureCount, execErr) err = h.setFailureCount(tx, stmt.Sequence, failureCount, execErr, stmt.InstanceID)
logging.WithFields("projection", h.ProjectionName, "seq", stmt.Sequence).OnError(err).Warn("unable to update failure count") logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).OnError(err).Warn("unable to update failure count")
return failureCount >= h.maxFailureCount return failureCount >= h.maxFailureCount
} }
func (h *StatementHandler) failureCount(tx *sql.Tx, seq uint64) (count uint, err error) { func (h *StatementHandler) failureCount(tx *sql.Tx, seq uint64, instanceID string) (count uint, err error) {
row := tx.QueryRow(h.failureCountStmt, h.ProjectionName, seq) row := tx.QueryRow(h.failureCountStmt, h.ProjectionName, seq, instanceID)
if err = row.Err(); err != nil { if err = row.Err(); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count") return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
} }
if err = row.Scan(&count); err != nil { if err = row.Scan(&count); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scann count") return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
} }
return count, nil return count, nil
} }
func (h *StatementHandler) setFailureCount(tx *sql.Tx, seq uint64, count uint, err error) error { func (h *StatementHandler) setFailureCount(tx *sql.Tx, seq uint64, count uint, err error, instanceID string) error {
_, dbErr := tx.Exec(h.setFailureCountStmt, h.ProjectionName, seq, count, err.Error()) _, dbErr := tx.Exec(h.setFailureCountStmt, h.ProjectionName, seq, count, err.Error(), instanceID)
if dbErr != nil { if dbErr != nil {
return errors.ThrowInternal(dbErr, "CRDB-4Ht4x", "set failure count failed") return errors.ThrowInternal(dbErr, "CRDB-4Ht4x", "set failure count failed")
} }

View File

@ -26,7 +26,8 @@ type StatementHandlerConfig struct {
MaxFailureCount uint MaxFailureCount uint
BulkLimit uint64 BulkLimit uint64
Reducers []handler.AggregateReducer Reducers []handler.AggregateReducer
InitCheck *handler.Check
} }
type StatementHandler struct { type StatementHandler struct {
@ -75,6 +76,9 @@ func NewStatementHandler(
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName), Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
} }
err := h.Init(ctx, config.InitCheck)
logging.OnError(err).Fatal("unable to initialize projections")
go h.ProjectionHandler.Process( go h.ProjectionHandler.Process(
ctx, ctx,
h.reduce, h.reduce,
@ -214,7 +218,7 @@ func (h *StatementHandler) executeStmts(
continue continue
} }
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] { if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "seq", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match") logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
break break
} }
err := h.executeStmt(tx, stmt) err := h.executeStmt(tx, stmt)

View File

@ -28,6 +28,7 @@ type testEvent struct {
sequence uint64 sequence uint64
previousSequence uint64 previousSequence uint64
aggregateType eventstore.AggregateType aggregateType eventstore.AggregateType
instanceID string
} }
func (e *testEvent) Sequence() uint64 { func (e *testEvent) Sequence() uint64 {
@ -36,7 +37,8 @@ func (e *testEvent) Sequence() uint64 {
func (e *testEvent) Aggregate() eventstore.Aggregate { func (e *testEvent) Aggregate() eventstore.Aggregate {
return eventstore.Aggregate{ return eventstore.Aggregate{
Type: e.aggregateType, Type: e.aggregateType,
InstanceID: e.instanceID,
} }
} }
@ -786,6 +788,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg", aggregateType: "agg",
sequence: 5, sequence: 5,
previousSequence: 0, previousSequence: 0,
instanceID: "instanceID",
}, },
[]handler.Column{ []handler.Column{
{ {
@ -798,6 +801,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg", aggregateType: "agg",
sequence: 6, sequence: 6,
previousSequence: 5, previousSequence: 5,
instanceID: "instanceID",
}, },
[]handler.Column{ []handler.Column{
{ {
@ -810,6 +814,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg", aggregateType: "agg",
sequence: 7, sequence: 7,
previousSequence: 6, previousSequence: 6,
instanceID: "instanceID",
}, },
[]handler.Column{ []handler.Column{
{ {
@ -830,8 +835,8 @@ func TestStatementHandler_executeStmts(t *testing.T) {
expectSavePoint(), expectSavePoint(),
expectCreateErr("my_projection", []string{"col"}, []string{"$1"}, sql.ErrConnDone), expectCreateErr("my_projection", []string{"col"}, []string{"$1"}, sql.ErrConnDone),
expectSavePointRollback(), expectSavePointRollback(),
expectFailureCount("failed_events", "my_projection", 6, 3), expectFailureCount("failed_events", "my_projection", "instanceID", 6, 3),
expectUpdateFailureCount("failed_events", "my_projection", 6, 4), expectUpdateFailureCount("failed_events", "my_projection", "instanceID", 6, 4),
}, },
idx: 0, idx: 0,
}, },
@ -850,6 +855,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg", aggregateType: "agg",
sequence: 5, sequence: 5,
previousSequence: 0, previousSequence: 0,
instanceID: "instanceID",
}, },
[]handler.Column{ []handler.Column{
{ {
@ -862,6 +868,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg", aggregateType: "agg",
sequence: 6, sequence: 6,
previousSequence: 5, previousSequence: 5,
instanceID: "instanceID",
}, },
[]handler.Column{ []handler.Column{
{ {
@ -874,6 +881,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
aggregateType: "agg", aggregateType: "agg",
sequence: 7, sequence: 7,
previousSequence: 6, previousSequence: 6,
instanceID: "instanceID",
}, },
[]handler.Column{ []handler.Column{
{ {
@ -894,8 +902,8 @@ func TestStatementHandler_executeStmts(t *testing.T) {
expectSavePoint(), expectSavePoint(),
expectCreateErr("my_projection", []string{"col2"}, []string{"$1"}, sql.ErrConnDone), expectCreateErr("my_projection", []string{"col2"}, []string{"$1"}, sql.ErrConnDone),
expectSavePointRollback(), expectSavePointRollback(),
expectFailureCount("failed_events", "my_projection", 6, 4), expectFailureCount("failed_events", "my_projection", "instanceID", 6, 4),
expectUpdateFailureCount("failed_events", "my_projection", 6, 5), expectUpdateFailureCount("failed_events", "my_projection", "instanceID", 6, 5),
expectSavePoint(), expectSavePoint(),
expectCreate("my_projection", []string{"col3"}, []string{"$1"}), expectCreate("my_projection", []string{"col3"}, []string{"$1"}),
expectSavePointRelease(), expectSavePointRelease(),

View File

@ -0,0 +1,320 @@
package crdb
import (
"context"
"errors"
"fmt"
"strings"
"github.com/caos/logging"
"github.com/lib/pq"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/handler"
)
type Table struct {
columns []*Column
primaryKey PrimaryKey
indices []*Index
}
func NewTable(columns []*Column, key PrimaryKey, indices ...*Index) *Table {
return &Table{
columns: columns,
primaryKey: key,
indices: indices,
}
}
type SuffixedTable struct {
Table
suffix string
}
func NewSuffixedTable(columns []*Column, key PrimaryKey, suffix string, indices ...*Index) *SuffixedTable {
return &SuffixedTable{
Table: Table{
columns: columns,
primaryKey: key,
indices: indices,
},
suffix: suffix,
}
}
type Column struct {
Name string
Type ColumnType
nullable bool
defaultValue interface{}
deleteCascade string
}
type ColumnOption func(*Column)
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *Column {
column := &Column{
Name: name,
Type: columnType,
nullable: false,
defaultValue: nil,
}
for _, opt := range opts {
opt(column)
}
return column
}
func Nullable() ColumnOption {
return func(c *Column) {
c.nullable = true
}
}
func Default(value interface{}) ColumnOption {
return func(c *Column) {
c.defaultValue = value
}
}
func DeleteCascade(column string) ColumnOption {
return func(c *Column) {
c.deleteCascade = column
}
}
type PrimaryKey []string
func NewPrimaryKey(columnNames ...string) PrimaryKey {
return columnNames
}
type ColumnType int32
const (
ColumnTypeText ColumnType = iota
ColumnTypeTextArray
ColumnTypeJSONB
ColumnTypeBytes
ColumnTypeTimestamp
ColumnTypeEnum
ColumnTypeEnumArray
ColumnTypeInt64
ColumnTypeBool
)
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
i := &Index{
Name: name,
Columns: columns,
bucketCount: 0,
}
for _, opt := range opts {
opt(i)
}
return i
}
type Index struct {
Name string
Columns []string
bucketCount uint16
}
type indexOpts func(*Index)
func Hash(bucketsCount uint16) indexOpts {
return func(i *Index) {
i.bucketCount = bucketsCount
}
}
//Init implements handler.Init
func (h *StatementHandler) Init(ctx context.Context, checks ...*handler.Check) error {
for _, check := range checks {
if check == nil || check.IsNoop() {
return nil
}
tx, err := h.client.BeginTx(ctx, nil)
if err != nil {
return caos_errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
}
for i, execute := range check.Executes {
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("executing check")
next, err := execute(h.client, h.ProjectionName)
if err != nil {
tx.Rollback()
return err
}
if !next {
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("skipping next check")
break
}
}
if err := tx.Commit(); err != nil {
return err
}
}
return nil
}
func NewTableCheck(table *Table, opts ...execOption) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
return createTableStatement(table, config.tableName, "")
}
executes := make([]func(handler.Executer, string) (bool, error), len(table.indices)+1)
executes[0] = execNextIfExists(config, create, opts, true)
for i, index := range table.indices {
executes[i+1] = execNextIfExists(config, createIndexStatement(index), opts, true)
}
return &handler.Check{
Executes: executes,
}
}
func NewMultiTableCheck(primaryTable *Table, secondaryTables ...*SuffixedTable) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
stmt := createTableStatement(primaryTable, config.tableName, "")
for _, table := range secondaryTables {
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
}
return stmt
}
return &handler.Check{
Executes: []func(handler.Executer, string) (bool, error){
execNextIfExists(config, create, nil, true),
},
}
}
func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
var stmt string
for _, table := range secondaryTables {
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
}
stmt += createViewStatement(config.tableName, selectStmt)
return stmt
}
return &handler.Check{
Executes: []func(handler.Executer, string) (bool, error){
execNextIfExists(config, create, nil, false),
},
}
}
func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) {
return func(handler handler.Executer, name string) (bool, error) {
err := exec(config, q, opts)(handler, name)
if isErrAlreadyExists(err) {
return executeNext, nil
}
return false, err
}
}
func isErrAlreadyExists(err error) bool {
caosErr := &caos_errs.CaosError{}
if !errors.As(err, &caosErr) {
return false
}
sqlErr, ok := caosErr.GetParent().(*pq.Error)
if !ok {
return false
}
return sqlErr.Routine == "NewRelationAlreadyExistsError"
}
func createTableStatement(table *Table, tableName string, suffix string) string {
stmt := fmt.Sprintf("CREATE TABLE %s (%s, PRIMARY KEY (%s)",
tableName+suffix,
createColumnsStatement(table.columns, tableName),
strings.Join(table.primaryKey, ", "),
)
for _, index := range table.indices {
stmt += fmt.Sprintf(", INDEX %s (%s)", index.Name, strings.Join(index.Columns, ","))
}
return stmt + ");"
}
func createViewStatement(viewName string, selectStmt string) string {
return fmt.Sprintf("CREATE VIEW %s AS %s",
viewName,
selectStmt,
)
}
func createIndexStatement(index *Index) func(config execConfig) string {
return func(config execConfig) string {
stmt := fmt.Sprintf("CREATE INDEX %s ON %s (%s)",
index.Name,
config.tableName,
strings.Join(index.Columns, ","),
)
if index.bucketCount == 0 {
return stmt + ";"
}
return fmt.Sprintf("SET experimental_enable_hash_sharded_indexes=on; %s USING HASH WITH BUCKET_COUNT = %d;",
stmt, index.bucketCount)
}
}
func createColumnsStatement(cols []*Column, tableName string) string {
columns := make([]string, len(cols))
for i, col := range cols {
column := col.Name + " " + columnType(col.Type)
if !col.nullable {
column += " NOT NULL"
}
if col.defaultValue != nil {
column += " DEFAULT " + defaultValue(col.defaultValue)
}
if col.deleteCascade != "" {
column += fmt.Sprintf(" REFERENCES %s (%s) ON DELETE CASCADE", tableName, col.deleteCascade)
}
columns[i] = column
}
return strings.Join(columns, ",")
}
func defaultValue(value interface{}) string {
switch v := value.(type) {
case string:
return "'" + v + "'"
default:
return fmt.Sprintf("%v", v)
}
}
func columnType(columnType ColumnType) string {
switch columnType {
case ColumnTypeText:
return "TEXT"
case ColumnTypeTextArray:
return "TEXT[]"
case ColumnTypeTimestamp:
return "TIMESTAMPTZ"
case ColumnTypeEnum:
return "SMALLINT"
case ColumnTypeEnumArray:
return "SMALLINT[]"
case ColumnTypeInt64:
return "BIGINT"
case ColumnTypeBool:
return "BOOLEAN"
case ColumnTypeJSONB:
return "JSONB"
case ColumnTypeBytes:
return "BYTES"
default:
panic("") //TODO: remove?
return ""
}
}

View File

@ -37,7 +37,7 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
workerName, err := os.Hostname() workerName, err := os.Hostname()
if err != nil || workerName == "" { if err != nil || workerName == "" {
workerName, err = id.SonyFlakeGenerator.Next() workerName, err = id.SonyFlakeGenerator.Next()
logging.Log("CRDB-bdO56").OnError(err).Panic("unable to generate lockID") logging.OnError(err).Panic("unable to generate lockID")
} }
return &locker{ return &locker{
client: client, client: client,

View File

@ -6,7 +6,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/caos/zitadel/internal/errors" caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore" "github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/handler" "github.com/caos/zitadel/internal/eventstore/handler"
) )
@ -46,6 +46,7 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
AggregateType: event.Aggregate().Type, AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(), Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(), PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts), Execute: exec(config, q, opts),
} }
} }
@ -71,6 +72,7 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
AggregateType: event.Aggregate().Type, AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(), Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(), PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts), Execute: exec(config, q, opts),
} }
} }
@ -104,6 +106,7 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
AggregateType: event.Aggregate().Type, AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(), Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(), PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts), Execute: exec(config, q, opts),
} }
} }
@ -129,6 +132,7 @@ func NewDeleteStatement(event eventstore.Event, conditions []handler.Condition,
AggregateType: event.Aggregate().Type, AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(), Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(), PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts), Execute: exec(config, q, opts),
} }
} }
@ -138,6 +142,7 @@ func NewNoOpStatement(event eventstore.Event) *handler.Statement {
AggregateType: event.Aggregate().Type, AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(), Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(), PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
} }
} }
@ -153,6 +158,7 @@ func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Ex
AggregateType: event.Aggregate().Type, AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(), Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(), PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: multiExec(execs), Execute: multiExec(execs),
} }
} }
@ -278,6 +284,7 @@ func NewCopyStatement(event eventstore.Event, cols []handler.Column, conds []han
AggregateType: event.Aggregate().Type, AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(), Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(), PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts), Execute: exec(config, q, opts),
} }
} }
@ -327,7 +334,7 @@ func exec(config execConfig, q query, opts []execOption) Exec {
} }
if _, err := ex.Exec(q(config), config.args...); err != nil { if _, err := ex.Exec(q(config), config.args...); err != nil {
return errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed") return caos_errs.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
} }
return nil return nil

View File

@ -608,6 +608,7 @@ func TestNewNoOpStatement(t *testing.T) {
aggregateType: "agg", aggregateType: "agg",
sequence: 5, sequence: 5,
previousSequence: 3, previousSequence: 3,
instanceID: "instanceID",
}, },
}, },
want: &handler.Statement{ want: &handler.Statement{
@ -615,6 +616,7 @@ func TestNewNoOpStatement(t *testing.T) {
Execute: nil, Execute: nil,
Sequence: 5, Sequence: 5,
PreviousSequence: 3, PreviousSequence: 3,
InstanceID: "instanceID",
}, },
}, },
} }

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/caos/logging" "github.com/caos/logging"
"github.com/caos/zitadel/internal/eventstore" "github.com/caos/zitadel/internal/eventstore"
) )
@ -270,7 +271,7 @@ func (h *ProjectionHandler) fetchBulkStmts(
for _, event := range events { for _, event := range events {
if err = h.processEvent(ctx, event, reduce); err != nil { if err = h.processEvent(ctx, event, reduce); err != nil {
logging.WithFields("projection", h.ProjectionName, "seq", event.Sequence()).WithError(err).Warn("unable to process event in bulk") logging.WithFields("projection", h.ProjectionName, "sequence", event.Sequence(), "instanceID", event.Aggregate().InstanceID).WithError(err).Warn("unable to process event in bulk")
return false, err return false, err
} }
} }

View File

@ -0,0 +1,14 @@
package handler
import "context"
//Init initializes the projection with the given check
type Init func(context.Context, *Check) error
type Check struct {
Executes []func(ex Executer, projectionName string) (bool, error)
}
func (c *Check) IsNoop() bool {
return len(c.Executes) == 0
}

View File

@ -27,6 +27,7 @@ type Statement struct {
AggregateType eventstore.AggregateType AggregateType eventstore.AggregateType
Sequence uint64 Sequence uint64
PreviousSequence uint64 PreviousSequence uint64
InstanceID string
Execute func(ex Executer, projectionName string) error Execute func(ex Executer, projectionName string) error
} }

View File

@ -12,7 +12,7 @@ type ReadModel struct {
ChangeDate time.Time `json:"-"` ChangeDate time.Time `json:"-"`
Events []Event `json:"-"` Events []Event `json:"-"`
ResourceOwner string `json:"-"` ResourceOwner string `json:"-"`
Tenant string `json:"-"` InstanceID string `json:"-"`
} }
//AppendEvents adds all the events to the read model. //AppendEvents adds all the events to the read model.
@ -35,8 +35,8 @@ func (rm *ReadModel) Reduce() error {
if rm.ResourceOwner == "" { if rm.ResourceOwner == "" {
rm.ResourceOwner = rm.Events[0].Aggregate().ResourceOwner rm.ResourceOwner = rm.Events[0].Aggregate().ResourceOwner
} }
if rm.Tenant == "" { if rm.InstanceID == "" {
rm.Tenant = rm.Events[0].Aggregate().Tenant rm.InstanceID = rm.Events[0].Aggregate().InstanceID
} }
if rm.CreationDate.IsZero() { if rm.CreationDate.IsZero() {

View File

@ -56,9 +56,9 @@ type Event struct {
// an aggregate can only be managed by one organisation // an aggregate can only be managed by one organisation
// use the ID of the org // use the ID of the org
ResourceOwner sql.NullString ResourceOwner sql.NullString
//Tenant is the system where this event belongs to //InstanceID is the instance where this event belongs to
// use the ID of the tenant // use the ID of the instance
Tenant sql.NullString InstanceID sql.NullString
} }
//EventType is the description of the change //EventType is the description of the change

View File

@ -66,8 +66,8 @@ const (
FieldSequence FieldSequence
//FieldResourceOwner represents the resource owner field //FieldResourceOwner represents the resource owner field
FieldResourceOwner FieldResourceOwner
//FieldTenant represents the tenant field //FieldInstanceID represents the instance id field
FieldTenant FieldInstanceID
//FieldEditorService represents the editor service field //FieldEditorService represents the editor service field
FieldEditorService FieldEditorService
//FieldEditorUser represents the editor user field //FieldEditorUser represents the editor user field

View File

@ -30,7 +30,7 @@ const (
" SELECT MAX(event_sequence) seq, 1 join_me" + " SELECT MAX(event_sequence) seq, 1 join_me" +
" FROM eventstore.events" + " FROM eventstore.events" +
" WHERE aggregate_type = $2" + " WHERE aggregate_type = $2" +
" AND (CASE WHEN $9::STRING IS NULL THEN tenant is null else tenant = $9::STRING END)" + " AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
") AS agg_type " + ") AS agg_type " +
// combined with // combined with
"LEFT JOIN " + "LEFT JOIN " +
@ -39,7 +39,7 @@ const (
" SELECT event_sequence seq, resource_owner ro, 1 join_me" + " SELECT event_sequence seq, resource_owner ro, 1 join_me" +
" FROM eventstore.events" + " FROM eventstore.events" +
" WHERE aggregate_type = $2 AND aggregate_id = $3" + " WHERE aggregate_type = $2 AND aggregate_id = $3" +
" AND (CASE WHEN $9::STRING IS NULL THEN tenant is null else tenant = $9::STRING END)" + " AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
" ORDER BY event_sequence DESC" + " ORDER BY event_sequence DESC" +
" LIMIT 1" + " LIMIT 1" +
") AS agg USING(join_me)" + ") AS agg USING(join_me)" +
@ -54,7 +54,7 @@ const (
" editor_user," + " editor_user," +
" editor_service," + " editor_service," +
" resource_owner," + " resource_owner," +
" tenant," + " instance_id," +
" event_sequence," + " event_sequence," +
" previous_aggregate_sequence," + " previous_aggregate_sequence," +
" previous_aggregate_type_sequence" + " previous_aggregate_type_sequence" +
@ -70,12 +70,12 @@ const (
" $6::VARCHAR AS editor_user," + " $6::VARCHAR AS editor_user," +
" $7::VARCHAR AS editor_service," + " $7::VARCHAR AS editor_service," +
" IFNULL((resource_owner), $8::VARCHAR) AS resource_owner," + " IFNULL((resource_owner), $8::VARCHAR) AS resource_owner," +
" $9::VARCHAR AS tenant," + " $9::VARCHAR AS instance_id," +
" NEXTVAL(CONCAT('eventstore.', IFNULL($9, 'system'), '_seq'))," + " NEXTVAL(CONCAT('eventstore.', IFNULL($9, 'system'), '_seq'))," +
" aggregate_sequence AS previous_aggregate_sequence," + " aggregate_sequence AS previous_aggregate_sequence," +
" aggregate_type_sequence AS previous_aggregate_type_sequence " + " aggregate_type_sequence AS previous_aggregate_type_sequence " +
"FROM previous_data " + "FROM previous_data " +
"RETURNING id, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, creation_date, resource_owner, tenant" "RETURNING id, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, creation_date, resource_owner, instance_id"
uniqueInsert = `INSERT INTO eventstore.unique_constraints uniqueInsert = `INSERT INTO eventstore.unique_constraints
( (
@ -120,8 +120,8 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
event.EditorUser, event.EditorUser,
event.EditorService, event.EditorService,
event.ResourceOwner, event.ResourceOwner,
event.Tenant, event.InstanceID,
).Scan(&event.ID, &event.Sequence, &previousAggregateSequence, &previousAggregateTypeSequence, &event.CreationDate, &event.ResourceOwner, &event.Tenant) ).Scan(&event.ID, &event.Sequence, &previousAggregateSequence, &previousAggregateTypeSequence, &event.CreationDate, &event.ResourceOwner, &event.InstanceID)
event.PreviousAggregateSequence = uint64(previousAggregateSequence) event.PreviousAggregateSequence = uint64(previousAggregateSequence)
event.PreviousAggregateTypeSequence = uint64(previousAggregateTypeSequence) event.PreviousAggregateTypeSequence = uint64(previousAggregateTypeSequence)
@ -132,7 +132,7 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
"aggregateId", event.AggregateID, "aggregateId", event.AggregateID,
"aggregateType", event.AggregateType, "aggregateType", event.AggregateType,
"eventType", event.Type, "eventType", event.Type,
"tenant", event.Tenant, "instanceID", event.InstanceID,
).WithError(err).Info("query failed") ).WithError(err).Info("query failed")
return caos_errs.ThrowInternal(err, "SQL-SBP37", "unable to create event") return caos_errs.ThrowInternal(err, "SQL-SBP37", "unable to create event")
} }
@ -229,7 +229,7 @@ func (db *CRDB) eventQuery() string {
", editor_service" + ", editor_service" +
", editor_user" + ", editor_user" +
", resource_owner" + ", resource_owner" +
", tenant" + ", instance_id" +
", aggregate_type" + ", aggregate_type" +
", aggregate_id" + ", aggregate_id" +
", aggregate_version" + ", aggregate_version" +
@ -250,8 +250,8 @@ func (db *CRDB) columnName(col repository.Field) string {
return "event_sequence" return "event_sequence"
case repository.FieldResourceOwner: case repository.FieldResourceOwner:
return "resource_owner" return "resource_owner"
case repository.FieldTenant: case repository.FieldInstanceID:
return "tenant" return "instance_id"
case repository.FieldEditorService: case repository.FieldEditorService:
return "editor_service" return "editor_service"
case repository.FieldEditorUser: case repository.FieldEditorUser:

View File

@ -109,7 +109,7 @@ func eventsScanner(scanner scan, dest interface{}) (err error) {
&event.EditorService, &event.EditorService,
&event.EditorUser, &event.EditorUser,
&event.ResourceOwner, &event.ResourceOwner,
&event.Tenant, &event.InstanceID,
&event.AggregateType, &event.AggregateType,
&event.AggregateID, &event.AggregateID,
&event.Version, &event.Version,

View File

@ -130,7 +130,7 @@ func Test_prepareColumns(t *testing.T) {
dest: &[]*repository.Event{}, dest: &[]*repository.Event{},
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
expected: []*repository.Event{ expected: []*repository.Event{
{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)}, {AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
}, },
@ -146,7 +146,7 @@ func Test_prepareColumns(t *testing.T) {
dest: []*repository.Event{}, dest: []*repository.Event{},
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsErrorInvalidArgument, dbErr: errors.IsErrorInvalidArgument,
}, },
}, },
@ -158,7 +158,7 @@ func Test_prepareColumns(t *testing.T) {
dbErr: sql.ErrConnDone, dbErr: sql.ErrConnDone,
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsInternal, dbErr: errors.IsInternal,
}, },
}, },
@ -592,7 +592,7 @@ func Test_query_events_mocked(t *testing.T) {
}, },
fields: fields{ fields: fields{
mock: newMockClient(t).expectQuery(t, mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")}, []driver.Value{repository.AggregateType("user")},
), ),
}, },
@ -621,7 +621,7 @@ func Test_query_events_mocked(t *testing.T) {
}, },
fields: fields{ fields: fields{
mock: newMockClient(t).expectQuery(t, mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence LIMIT \$2`, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)}, []driver.Value{repository.AggregateType("user"), uint64(5)},
), ),
}, },
@ -650,7 +650,7 @@ func Test_query_events_mocked(t *testing.T) {
}, },
fields: fields{ fields: fields{
mock: newMockClient(t).expectQuery(t, mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC LIMIT \$2`, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)}, []driver.Value{repository.AggregateType("user"), uint64(5)},
), ),
}, },
@ -679,7 +679,7 @@ func Test_query_events_mocked(t *testing.T) {
}, },
fields: fields{ fields: fields{
mock: newMockClient(t).expectQueryErr(t, mock: newMockClient(t).expectQueryErr(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")}, []driver.Value{repository.AggregateType("user")},
sql.ErrConnDone), sql.ErrConnDone),
}, },
@ -708,7 +708,7 @@ func Test_query_events_mocked(t *testing.T) {
}, },
fields: fields{ fields: fields{
mock: newMockClient(t).expectQuery(t, mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")}, []driver.Value{repository.AggregateType("user")},
&repository.Event{Sequence: 100}), &repository.Event{Sequence: 100}),
}, },
@ -776,7 +776,7 @@ func Test_query_events_mocked(t *testing.T) {
}, },
fields: fields{ fields: fields{
mock: newMockClient(t).expectQuery(t, mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) OR \( aggregate_type = \$2 AND aggregate_id = \$3 \) ORDER BY event_sequence DESC LIMIT \$4`, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) OR \( aggregate_type = \$2 AND aggregate_id = \$3 \) ORDER BY event_sequence DESC LIMIT \$4`,
[]driver.Value{repository.AggregateType("user"), repository.AggregateType("org"), "asdf42", uint64(5)}, []driver.Value{repository.AggregateType("user"), repository.AggregateType("org"), "asdf42", uint64(5)},
), ),
}, },

View File

@ -12,7 +12,7 @@ type SearchQueryBuilder struct {
limit uint64 limit uint64
desc bool desc bool
resourceOwner string resourceOwner string
tenant string instanceID string
queries []*SearchQuery queries []*SearchQuery
} }
@ -68,9 +68,9 @@ func (factory *SearchQueryBuilder) ResourceOwner(resourceOwner string) *SearchQu
return factory return factory
} }
//Tenant defines the tenant (system) of the events //InstanceID defines the instanceID (system) of the events
func (factory *SearchQueryBuilder) Tenant(tenant string) *SearchQueryBuilder { func (factory *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder {
factory.tenant = tenant factory.instanceID = instanceID
return factory return factory
} }
@ -145,13 +145,13 @@ func (query *SearchQuery) Builder() *SearchQueryBuilder {
return query.builder return query.builder
} }
func (builder *SearchQueryBuilder) build(tenantID string) (*repository.SearchQuery, error) { func (builder *SearchQueryBuilder) build(instanceID string) (*repository.SearchQuery, error) {
if builder == nil || if builder == nil ||
len(builder.queries) < 1 || len(builder.queries) < 1 ||
builder.columns.Validate() != nil { builder.columns.Validate() != nil {
return nil, errors.ThrowPreconditionFailed(nil, "MODEL-4m9gs", "builder invalid") return nil, errors.ThrowPreconditionFailed(nil, "MODEL-4m9gs", "builder invalid")
} }
builder.tenant = tenantID builder.instanceID = instanceID
filters := make([][]*repository.Filter, len(builder.queries)) filters := make([][]*repository.Filter, len(builder.queries))
for i, query := range builder.queries { for i, query := range builder.queries {
@ -163,7 +163,7 @@ func (builder *SearchQueryBuilder) build(tenantID string) (*repository.SearchQue
query.eventSequenceGreaterFilter, query.eventSequenceGreaterFilter,
query.eventSequenceLessFilter, query.eventSequenceLessFilter,
query.builder.resourceOwnerFilter, query.builder.resourceOwnerFilter,
query.builder.tenantFilter, query.builder.instanceIDFilter,
} { } {
if filter := f(); filter != nil { if filter := f(); filter != nil {
if err := filter.Validate(); err != nil { if err := filter.Validate(); err != nil {
@ -247,11 +247,11 @@ func (builder *SearchQueryBuilder) resourceOwnerFilter() *repository.Filter {
return repository.NewFilter(repository.FieldResourceOwner, builder.resourceOwner, repository.OperationEquals) return repository.NewFilter(repository.FieldResourceOwner, builder.resourceOwner, repository.OperationEquals)
} }
func (builder *SearchQueryBuilder) tenantFilter() *repository.Filter { func (builder *SearchQueryBuilder) instanceIDFilter() *repository.Filter {
if builder.tenant == "" { if builder.instanceID == "" {
return nil return nil
} }
return repository.NewFilter(repository.FieldTenant, builder.tenant, repository.OperationEquals) return repository.NewFilter(repository.FieldInstanceID, builder.instanceID, repository.OperationEquals)
} }
func (query *SearchQuery) eventDataFilter() *repository.Filter { func (query *SearchQuery) eventDataFilter() *repository.Filter {

View File

@ -224,9 +224,9 @@ func TestSearchQuerybuilderSetters(t *testing.T) {
func TestSearchQuerybuilderBuild(t *testing.T) { func TestSearchQuerybuilderBuild(t *testing.T) {
type args struct { type args struct {
columns Columns columns Columns
setters []func(*SearchQueryBuilder) *SearchQueryBuilder setters []func(*SearchQueryBuilder) *SearchQueryBuilder
tenant string instanceID string
} }
type res struct { type res struct {
isErr func(err error) bool isErr func(err error) bool
@ -622,7 +622,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
}, },
}, },
{ {
name: "filter aggregate type and tenant", name: "filter aggregate type and instanceID",
args: args{ args: args{
columns: ColumnsEvent, columns: ColumnsEvent,
setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{ setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{
@ -630,7 +630,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
testSetAggregateTypes("user"), testSetAggregateTypes("user"),
), ),
}, },
tenant: "tenant", instanceID: "instanceID",
}, },
res: res{ res: res{
isErr: nil, isErr: nil,
@ -641,7 +641,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
Filters: [][]*repository.Filter{ Filters: [][]*repository.Filter{
{ {
repository.NewFilter(repository.FieldAggregateType, repository.AggregateType("user"), repository.OperationEquals), repository.NewFilter(repository.FieldAggregateType, repository.AggregateType("user"), repository.OperationEquals),
repository.NewFilter(repository.FieldTenant, "tenant", repository.OperationEquals), repository.NewFilter(repository.FieldInstanceID, "instanceID", repository.OperationEquals),
}, },
}, },
}, },
@ -668,7 +668,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
for _, f := range tt.args.setters { for _, f := range tt.args.setters {
builder = f(builder) builder = f(builder)
} }
query, err := builder.build(tt.args.tenant) query, err := builder.build(tt.args.instanceID)
if tt.res.isErr != nil && !tt.res.isErr(err) { if tt.res.isErr != nil && !tt.res.isErr(err) {
t.Errorf("wrong error(%T): %v", err, err) t.Errorf("wrong error(%T): %v", err, err)
return return

View File

@ -122,7 +122,7 @@ func mapEventToV1Event(event Event) *models.Event {
AggregateType: models.AggregateType(event.Aggregate().Type), AggregateType: models.AggregateType(event.Aggregate().Type),
AggregateID: event.Aggregate().ID, AggregateID: event.Aggregate().ID,
ResourceOwner: event.Aggregate().ResourceOwner, ResourceOwner: event.Aggregate().ResourceOwner,
Tenant: event.Aggregate().Tenant, InstanceID: event.Aggregate().InstanceID,
EditorService: event.EditorService(), EditorService: event.EditorService(),
EditorUser: event.EditorUser(), EditorUser: event.EditorUser(),
Data: event.DataAsBytes(), Data: event.DataAsBytes(),

View File

@ -7,15 +7,16 @@ import (
"time" "time"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/eventstore/v1/models"
) )
const ( const (
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1` selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1`
) )
var ( var (
eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "tenant", "aggregate_type", "aggregate_id", "aggregate_version"} eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "instance_id", "aggregate_type", "aggregate_id", "aggregate_version"}
expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence LIMIT \$2`).String() expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence LIMIT \$2`).String()
expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence DESC`).String() expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence DESC`).String()
expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String() expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
@ -23,7 +24,7 @@ var (
expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence`).String() expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence`).String()
expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` + expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` +
`\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, tenant, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` + `\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` +
`SELECT \$1, \$2, \$3, \$4, COALESCE\(\$5, now\(\)\), \$6, \$7, \$8, \$9, \$10, \$11 ` + `SELECT \$1, \$2, \$3, \$4, COALESCE\(\$5, now\(\)\), \$6, \$7, \$8, \$9, \$10, \$11 ` +
`WHERE EXISTS \(` + `WHERE EXISTS \(` +
`SELECT 1 FROM eventstore\.events WHERE aggregate_type = \$12 AND aggregate_id = \$13 HAVING MAX\(event_sequence\) = \$14 OR \(\$14::BIGINT IS NULL AND COUNT\(\*\) = 0\)\) ` + `SELECT 1 FROM eventstore\.events WHERE aggregate_type = \$12 AND aggregate_id = \$13 HAVING MAX\(event_sequence\) = \$14 OR \(\$14::BIGINT IS NULL AND COUNT\(\*\) = 0\)\) ` +
@ -99,7 +100,7 @@ func (db *dbMock) expectRollback(err error) *dbMock {
func (db *dbMock) expectInsertEvent(e *models.Event, returnedSequence uint64) *dbMock { func (db *dbMock) expectInsertEvent(e *models.Event, returnedSequence uint64) *dbMock {
db.mock.ExpectQuery(expectedInsertStatement). db.mock.ExpectQuery(expectedInsertStatement).
WithArgs( WithArgs(
e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.Tenant, Sequence(e.PreviousSequence), e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.InstanceID, Sequence(e.PreviousSequence),
e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence), e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
). ).
WillReturnRows( WillReturnRows(
@ -113,7 +114,7 @@ func (db *dbMock) expectInsertEvent(e *models.Event, returnedSequence uint64) *d
func (db *dbMock) expectInsertEventError(e *models.Event) *dbMock { func (db *dbMock) expectInsertEventError(e *models.Event) *dbMock {
db.mock.ExpectQuery(expectedInsertStatement). db.mock.ExpectQuery(expectedInsertStatement).
WithArgs( WithArgs(
e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.Tenant, Sequence(e.PreviousSequence), e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.InstanceID, Sequence(e.PreviousSequence),
e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence), e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
). ).
WillReturnError(sql.ErrTxDone) WillReturnError(sql.ErrTxDone)
@ -124,7 +125,7 @@ func (db *dbMock) expectInsertEventError(e *models.Event) *dbMock {
func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, eventCount int) *dbMock { func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, eventCount int) *dbMock {
rows := sqlmock.NewRows(eventColumns) rows := sqlmock.NewRows(eventColumns)
for i := 0; i < eventCount; i++ { for i := 0; i < eventCount; i++ {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectQuery(expectedFilterEventsLimitFormat). db.mock.ExpectQuery(expectedFilterEventsLimitFormat).
WithArgs(aggregateType, limit). WithArgs(aggregateType, limit).
@ -135,7 +136,7 @@ func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, ev
func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *dbMock { func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *dbMock {
rows := sqlmock.NewRows(eventColumns) rows := sqlmock.NewRows(eventColumns)
for i := eventCount; i > 0; i-- { for i := eventCount; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectQuery(expectedFilterEventsDescFormat). db.mock.ExpectQuery(expectedFilterEventsDescFormat).
WillReturnRows(rows) WillReturnRows(rows)
@ -145,7 +146,7 @@ func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *
func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID string, limit uint64) *dbMock { func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
rows := sqlmock.NewRows(eventColumns) rows := sqlmock.NewRows(eventColumns)
for i := limit; i > 0; i-- { for i := limit; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit). db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit).
WithArgs(aggregateType, aggregateID, limit). WithArgs(aggregateType, aggregateID, limit).
@ -156,7 +157,7 @@ func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID
func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregateID string, limit uint64) *dbMock { func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
rows := sqlmock.NewRows(eventColumns) rows := sqlmock.NewRows(eventColumns)
for i := limit; i > 0; i-- { for i := limit; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "tenant", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit). db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit).
WithArgs(aggregateType, aggregateID, limit). WithArgs(aggregateType, aggregateID, limit).

View File

@ -8,9 +8,10 @@ import (
"strings" "strings"
"github.com/caos/logging" "github.com/caos/logging"
"github.com/lib/pq"
z_errors "github.com/caos/zitadel/internal/errors" z_errors "github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models" es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/lib/pq"
) )
const ( const (
@ -23,7 +24,7 @@ const (
", editor_service" + ", editor_service" +
", editor_user" + ", editor_user" +
", resource_owner" + ", resource_owner" +
", tenant" + ", instance_id" +
", aggregate_type" + ", aggregate_type" +
", aggregate_id" + ", aggregate_id" +
", aggregate_version" + ", aggregate_version" +
@ -117,7 +118,7 @@ func prepareColumns(columns es_models.Columns) (string, func(s scan, dest interf
&event.EditorService, &event.EditorService,
&event.EditorUser, &event.EditorUser,
&event.ResourceOwner, &event.ResourceOwner,
&event.Tenant, &event.InstanceID,
&event.AggregateType, &event.AggregateType,
&event.AggregateID, &event.AggregateID,
&event.AggregateVersion, &event.AggregateVersion,
@ -177,8 +178,8 @@ func getField(field es_models.Field) string {
return "event_sequence" return "event_sequence"
case es_models.Field_ResourceOwner: case es_models.Field_ResourceOwner:
return "resource_owner" return "resource_owner"
case es_models.Field_Tenant: case es_models.Field_InstanceID:
return "tenant" return "instance_id"
case es_models.Field_EditorService: case es_models.Field_EditorService:
return "editor_service" return "editor_service"
case es_models.Field_EditorUser: case es_models.Field_EditorUser:

View File

@ -6,9 +6,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/lib/pq"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models" es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/lib/pq"
) )
func Test_numberPlaceholder(t *testing.T) { func Test_numberPlaceholder(t *testing.T) {
@ -80,7 +81,7 @@ func Test_getField(t *testing.T) {
es_models.Field_AggregateID: "aggregate_id", es_models.Field_AggregateID: "aggregate_id",
es_models.Field_LatestSequence: "event_sequence", es_models.Field_LatestSequence: "event_sequence",
es_models.Field_ResourceOwner: "resource_owner", es_models.Field_ResourceOwner: "resource_owner",
es_models.Field_Tenant: "tenant", es_models.Field_InstanceID: "instance_id",
es_models.Field_EditorService: "editor_service", es_models.Field_EditorService: "editor_service",
es_models.Field_EditorUser: "editor_user", es_models.Field_EditorUser: "editor_user",
es_models.Field_EventType: "event_type", es_models.Field_EventType: "event_type",
@ -235,7 +236,7 @@ func Test_prepareColumns(t *testing.T) {
dest: new(es_models.Event), dest: new(es_models.Event),
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbRow: []interface{}{time.Time{}, es_models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", "", es_models.AggregateType("user"), "hodor", es_models.Version("")}, dbRow: []interface{}{time.Time{}, es_models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", "", es_models.AggregateType("user"), "hodor", es_models.Version("")},
expected: es_models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)}, expected: es_models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
}, },
@ -247,7 +248,7 @@ func Test_prepareColumns(t *testing.T) {
dest: new(uint64), dest: new(uint64),
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsErrorInvalidArgument, dbErr: errors.IsErrorInvalidArgument,
}, },
}, },
@ -259,7 +260,7 @@ func Test_prepareColumns(t *testing.T) {
dbErr: sql.ErrConnDone, dbErr: sql.ErrConnDone,
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbErr: errors.IsInternal, dbErr: errors.IsInternal,
}, },
}, },
@ -430,7 +431,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(), queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
rowScanner: true, rowScanner: true,
values: []interface{}{es_models.AggregateType("user")}, values: []interface{}{es_models.AggregateType("user")},
}, },
@ -441,7 +442,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5), queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
rowScanner: true, rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)}, values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5, limit: 5,
@ -453,7 +454,7 @@ func Test_buildQuery(t *testing.T) {
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(), queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, tenant, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2", query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
rowScanner: true, rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)}, values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5, limit: 5,

View File

@ -23,7 +23,7 @@ type Aggregate struct {
editorService string editorService string
editorUser string editorUser string
resourceOwner string resourceOwner string
tenant string instanceID string
Events []*Event Events []*Event
Precondition *precondition Precondition *precondition
} }
@ -56,7 +56,7 @@ func (a *Aggregate) AppendEvent(typ EventType, payload interface{}) (*Aggregate,
EditorService: a.editorService, EditorService: a.editorService,
EditorUser: a.editorUser, EditorUser: a.editorUser,
ResourceOwner: a.resourceOwner, ResourceOwner: a.resourceOwner,
Tenant: a.tenant, InstanceID: a.instanceID,
} }
a.Events = append(a.Events, e) a.Events = append(a.Events, e)

View File

@ -18,9 +18,10 @@ type option func(*Aggregate)
func (c *AggregateCreator) NewAggregate(ctx context.Context, id string, typ AggregateType, version Version, previousSequence uint64, opts ...option) (*Aggregate, error) { func (c *AggregateCreator) NewAggregate(ctx context.Context, id string, typ AggregateType, version Version, previousSequence uint64, opts ...option) (*Aggregate, error) {
ctxData := authz.GetCtxData(ctx) ctxData := authz.GetCtxData(ctx)
instance := authz.GetInstance(ctx)
editorUser := ctxData.UserID editorUser := ctxData.UserID
resourceOwner := ctxData.OrgID resourceOwner := ctxData.OrgID
tenant := ctxData.TenantID instanceID := instance.ID
aggregate := &Aggregate{ aggregate := &Aggregate{
ID: id, ID: id,
@ -31,7 +32,7 @@ func (c *AggregateCreator) NewAggregate(ctx context.Context, id string, typ Aggr
editorService: c.serviceName, editorService: c.serviceName,
editorUser: editorUser, editorUser: editorUser,
resourceOwner: resourceOwner, resourceOwner: resourceOwner,
tenant: tenant, instanceID: instanceID,
} }
for _, opt := range opts { for _, opt := range opts {

View File

@ -28,7 +28,7 @@ type Event struct {
EditorService string EditorService string
EditorUser string EditorUser string
ResourceOwner string ResourceOwner string
Tenant string InstanceID string
} }
func eventData(i interface{}) ([]byte, error) { func eventData(i interface{}) ([]byte, error) {

View File

@ -11,5 +11,5 @@ const (
Field_EditorUser Field_EditorUser
Field_EventType Field_EventType
Field_CreationDate Field_CreationDate
Field_Tenant Field_InstanceID
) )

View File

@ -8,7 +8,7 @@ type ObjectRoot struct {
AggregateID string `json:"-"` AggregateID string `json:"-"`
Sequence uint64 `json:"-"` Sequence uint64 `json:"-"`
ResourceOwner string `json:"-"` ResourceOwner string `json:"-"`
Tenant string `json:"-"` InstanceID string `json:"-"`
CreationDate time.Time `json:"-"` CreationDate time.Time `json:"-"`
ChangeDate time.Time `json:"-"` ChangeDate time.Time `json:"-"`
} }
@ -22,8 +22,8 @@ func (o *ObjectRoot) AppendEvent(event *Event) {
if o.ResourceOwner == "" { if o.ResourceOwner == "" {
o.ResourceOwner = event.ResourceOwner o.ResourceOwner = event.ResourceOwner
} }
if o.Tenant == "" { if o.InstanceID == "" {
o.Tenant = event.Tenant o.InstanceID = event.InstanceID
} }
o.ChangeDate = event.CreationDate o.ChangeDate = event.CreationDate

View File

@ -4,6 +4,7 @@ import (
"time" "time"
"github.com/caos/logging" "github.com/caos/logging"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
) )
@ -17,7 +18,7 @@ type SearchQueryFactory struct {
sequenceTo uint64 sequenceTo uint64
eventTypes []EventType eventTypes []EventType
resourceOwner string resourceOwner string
tenant string instanceID string
creationDate time.Time creationDate time.Time
} }
@ -63,8 +64,8 @@ func FactoryFromSearchQuery(query *SearchQuery) *SearchQueryFactory {
} }
case Field_ResourceOwner: case Field_ResourceOwner:
factory = factory.ResourceOwner(filter.value.(string)) factory = factory.ResourceOwner(filter.value.(string))
case Field_Tenant: case Field_InstanceID:
factory = factory.Tenant(filter.value.(string)) factory = factory.InstanceID(filter.value.(string))
case Field_EventType: case Field_EventType:
factory = factory.EventTypes(filter.value.([]EventType)...) factory = factory.EventTypes(filter.value.([]EventType)...)
case Field_EditorService, Field_EditorUser: case Field_EditorService, Field_EditorUser:
@ -123,8 +124,8 @@ func (factory *SearchQueryFactory) ResourceOwner(resourceOwner string) *SearchQu
return factory return factory
} }
func (factory *SearchQueryFactory) Tenant(tenant string) *SearchQueryFactory { func (factory *SearchQueryFactory) InstanceID(instanceID string) *SearchQueryFactory {
factory.tenant = tenant factory.instanceID = instanceID
return factory return factory
} }
@ -159,7 +160,7 @@ func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
factory.sequenceToFilter, factory.sequenceToFilter,
factory.eventTypeFilter, factory.eventTypeFilter,
factory.resourceOwnerFilter, factory.resourceOwnerFilter,
factory.tenantFilter, factory.instanceIDFilter,
factory.creationDateNewerFilter, factory.creationDateNewerFilter,
} { } {
if filter := f(); filter != nil { if filter := f(); filter != nil {
@ -231,11 +232,11 @@ func (factory *SearchQueryFactory) resourceOwnerFilter() *Filter {
return NewFilter(Field_ResourceOwner, factory.resourceOwner, Operation_Equals) return NewFilter(Field_ResourceOwner, factory.resourceOwner, Operation_Equals)
} }
func (factory *SearchQueryFactory) tenantFilter() *Filter { func (factory *SearchQueryFactory) instanceIDFilter() *Filter {
if factory.tenant == "" { if factory.instanceID == "" {
return nil return nil
} }
return NewFilter(Field_Tenant, factory.tenant, Operation_Equals) return NewFilter(Field_InstanceID, factory.instanceID, Operation_Equals)
} }
func (factory *SearchQueryFactory) creationDateNewerFilter() *Filter { func (factory *SearchQueryFactory) creationDateNewerFilter() *Filter {

View File

@ -69,8 +69,8 @@ func (q *SearchQuery) ResourceOwnerFilter(resourceOwner string) *SearchQuery {
return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals)) return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals))
} }
func (q *SearchQuery) TenantFilter(tenant string) *SearchQuery { func (q *SearchQuery) InstanceIDFilter(instanceID string) *SearchQuery {
return q.setFilter(NewFilter(Field_Tenant, tenant, Operation_Equals)) return q.setFilter(NewFilter(Field_InstanceID, instanceID, Operation_Equals))
} }
func (q *SearchQuery) CreationDateNewerFilter(time time.Time) *SearchQuery { func (q *SearchQuery) CreationDateNewerFilter(time time.Time) *SearchQuery {

View File

@ -54,7 +54,7 @@ func ReduceEvent(handler Handler, event *models.Event) {
unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery) unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery)
if err != nil { if err != nil {
logging.LogWithFields("HANDL-L6YH1", "seq", event.Sequence).Warn("filter failed") logging.WithFields("HANDL-L6YH1", "sequence", event.Sequence).Warn("filter failed")
return return
} }
@ -74,12 +74,12 @@ func ReduceEvent(handler Handler, event *models.Event) {
} }
err = handler.Reduce(unprocessedEvent) err = handler.Reduce(unprocessedEvent)
logging.LogWithFields("HANDL-V42TI", "seq", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed") logging.WithFields("HANDL-V42TI", "sequence", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
} }
if len(unprocessedEvents) == eventLimit { if len(unprocessedEvents) == eventLimit {
logging.LogWithFields("QUERY-BSqe9", "seq", event.Sequence).Warn("didnt process event") logging.WithFields("QUERY-BSqe9", "sequence", event.Sequence).Warn("didnt process event")
return return
} }
err = handler.Reduce(event) err = handler.Reduce(event)
logging.LogWithFields("HANDL-wQDL2", "seq", event.Sequence).OnError(err).Warn("reduce failed") logging.WithFields("HANDL-wQDL2", "sequence", event.Sequence).OnError(err).Warn("reduce failed")
} }

View File

@ -10,7 +10,7 @@ type WriteModel struct {
ProcessedSequence uint64 `json:"-"` ProcessedSequence uint64 `json:"-"`
Events []Event `json:"-"` Events []Event `json:"-"`
ResourceOwner string `json:"-"` ResourceOwner string `json:"-"`
Tenant string `json:"-"` InstanceID string `json:"-"`
ChangeDate time.Time `json:"-"` ChangeDate time.Time `json:"-"`
} }
@ -33,8 +33,8 @@ func (wm *WriteModel) Reduce() error {
if wm.ResourceOwner == "" { if wm.ResourceOwner == "" {
wm.ResourceOwner = wm.Events[0].Aggregate().ResourceOwner wm.ResourceOwner = wm.Events[0].Aggregate().ResourceOwner
} }
if wm.Tenant == "" { if wm.InstanceID == "" {
wm.Tenant = wm.Events[0].Aggregate().Tenant wm.InstanceID = wm.Events[0].Aggregate().InstanceID
} }
wm.ProcessedSequence = wm.Events[len(wm.Events)-1].Sequence() wm.ProcessedSequence = wm.Events[len(wm.Events)-1].Sequence()

View File

@ -51,6 +51,7 @@ const (
IDPConfigSearchKeyAggregateID IDPConfigSearchKeyAggregateID
IDPConfigSearchKeyIdpConfigID IDPConfigSearchKeyIdpConfigID
IDPConfigSearchKeyIdpProviderType IDPConfigSearchKeyIdpProviderType
IDPConfigSearchKeyInstanceID
) )
type IDPConfigSearchQuery struct { type IDPConfigSearchQuery struct {

View File

@ -48,6 +48,7 @@ const (
LabelPolicySearchKeyUnspecified LabelPolicySearchKey = iota LabelPolicySearchKeyUnspecified LabelPolicySearchKey = iota
LabelPolicySearchKeyAggregateID LabelPolicySearchKeyAggregateID
LabelPolicySearchKeyState LabelPolicySearchKeyState
LabelPolicySearchKeyInstanceID
) )
type LabelPolicySearchQuery struct { type LabelPolicySearchQuery struct {

View File

@ -24,6 +24,7 @@ const (
IDPConfigKeyAggregateID = "aggregate_id" IDPConfigKeyAggregateID = "aggregate_id"
IDPConfigKeyName = "name" IDPConfigKeyName = "name"
IDPConfigKeyProviderType = "idp_provider_type" IDPConfigKeyProviderType = "idp_provider_type"
IDPConfigKeyInstanceID = "instance_id"
) )
type IDPConfigView struct { type IDPConfigView struct {
@ -50,7 +51,8 @@ type IDPConfigView struct {
JWTKeysEndpoint string `json:"keysEndpoint" gorm:"jwt_keys_endpoint"` JWTKeysEndpoint string `json:"keysEndpoint" gorm:"jwt_keys_endpoint"`
JWTHeaderName string `json:"headerName" gorm:"jwt_header_name"` JWTHeaderName string `json:"headerName" gorm:"jwt_header_name"`
Sequence uint64 `json:"-" gorm:"column:sequence"` Sequence uint64 `json:"-" gorm:"column:sequence"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
} }
func IDPConfigViewToModel(idp *IDPConfigView) *model.IDPConfigView { func IDPConfigViewToModel(idp *IDPConfigView) *model.IDPConfigView {
@ -120,6 +122,7 @@ func (i *IDPConfigView) AppendEvent(providerType model.IDPProviderType, event *m
func (r *IDPConfigView) setRootData(event *models.Event) { func (r *IDPConfigView) setRootData(event *models.Event) {
r.AggregateID = event.AggregateID r.AggregateID = event.AggregateID
r.InstanceID = event.InstanceID
} }
func (r *IDPConfigView) SetData(event *models.Event) error { func (r *IDPConfigView) SetData(event *models.Event) error {

View File

@ -59,6 +59,8 @@ func (key IDPConfigSearchKey) ToColumnName() string {
return IDPConfigKeyName return IDPConfigKeyName
case iam_model.IDPConfigSearchKeyIdpProviderType: case iam_model.IDPConfigSearchKeyIdpProviderType:
return IDPConfigKeyProviderType return IDPConfigKeyProviderType
case iam_model.IDPConfigSearchKeyInstanceID:
return IDPConfigKeyInstanceID
default: default:
return "" return ""
} }

View File

@ -2,12 +2,14 @@ package model
import ( import (
"encoding/json" "encoding/json"
org_es_model "github.com/caos/zitadel/internal/org/repository/eventsourcing/model"
"time" "time"
org_es_model "github.com/caos/zitadel/internal/org/repository/eventsourcing/model"
es_model "github.com/caos/zitadel/internal/iam/repository/eventsourcing/model" es_model "github.com/caos/zitadel/internal/iam/repository/eventsourcing/model"
"github.com/caos/logging" "github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors" caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/iam/model" "github.com/caos/zitadel/internal/iam/model"
@ -32,7 +34,8 @@ type IDPProviderView struct {
IDPProviderType int32 `json:"idpProviderType" gorm:"column:idp_provider_type"` IDPProviderType int32 `json:"idpProviderType" gorm:"column:idp_provider_type"`
IDPState int32 `json:"-" gorm:"column:idp_state"` IDPState int32 `json:"-" gorm:"column:idp_state"`
Sequence uint64 `json:"-" gorm:"column:sequence"` Sequence uint64 `json:"-" gorm:"column:sequence"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
} }
func IDPProviderViewFromModel(provider *model.IDPProviderView) *IDPProviderView { func IDPProviderViewFromModel(provider *model.IDPProviderView) *IDPProviderView {
@ -87,6 +90,7 @@ func (i *IDPProviderView) AppendEvent(event *models.Event) (err error) {
func (r *IDPProviderView) setRootData(event *models.Event) { func (r *IDPProviderView) setRootData(event *models.Event) {
r.AggregateID = event.AggregateID r.AggregateID = event.AggregateID
r.InstanceID = event.InstanceID
} }
func (r *IDPProviderView) SetData(event *models.Event) error { func (r *IDPProviderView) SetData(event *models.Event) error {

View File

@ -19,6 +19,7 @@ import (
const ( const (
LabelPolicyKeyAggregateID = "aggregate_id" LabelPolicyKeyAggregateID = "aggregate_id"
LabelPolicyKeyState = "label_policy_state" LabelPolicyKeyState = "label_policy_state"
LabelPolicyKeyInstanceID = "instance_id"
) )
type LabelPolicyView struct { type LabelPolicyView struct {
@ -45,7 +46,8 @@ type LabelPolicyView struct {
DisableWatermark bool `json:"disableWatermark" gorm:"column:disable_watermark"` DisableWatermark bool `json:"disableWatermark" gorm:"column:disable_watermark"`
Default bool `json:"-" gorm:"-"` Default bool `json:"-" gorm:"-"`
Sequence uint64 `json:"-" gorm:"column:sequence"` Sequence uint64 `json:"-" gorm:"column:sequence"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
} }
type AssetView struct { type AssetView struct {
@ -189,6 +191,7 @@ func (i *LabelPolicyView) AppendEvent(event *models.Event) (err error) {
func (r *LabelPolicyView) setRootData(event *models.Event) { func (r *LabelPolicyView) setRootData(event *models.Event) {
r.AggregateID = event.AggregateID r.AggregateID = event.AggregateID
r.InstanceID = event.InstanceID
} }
func (r *LabelPolicyView) SetData(event *models.Event) error { func (r *LabelPolicyView) SetData(event *models.Event) error {

View File

@ -55,6 +55,9 @@ func (key LabelPolicySearchKey) ToColumnName() string {
return LabelPolicyKeyAggregateID return LabelPolicyKeyAggregateID
case iam_model.LabelPolicySearchKeyState: case iam_model.LabelPolicySearchKeyState:
return LabelPolicyKeyState return LabelPolicyKeyState
case iam_model.LabelPolicySearchKeyInstanceID:
return LabelPolicyKeyInstanceID
default: default:
return "" return ""
} }

View File

@ -0,0 +1,77 @@
package migration
import "github.com/caos/zitadel/internal/eventstore"
//SetupStep is the command pushed on the eventstore
type SetupStep struct {
typ eventstore.EventType
migration Migration
Name string `json:"name"`
Error error `json:"error,omitempty"`
done bool
}
func setupStartedCmd(migration Migration) eventstore.Command {
return &SetupStep{
migration: migration,
typ: startedType,
Name: migration.String(),
}
}
func setupDoneCmd(migration Migration, err error) eventstore.Command {
s := &SetupStep{
typ: doneType,
migration: migration,
Name: migration.String(),
}
if err != nil {
s.typ = failedType
s.Error = err
}
return s
}
func (s *SetupStep) Aggregate() eventstore.Aggregate {
return eventstore.Aggregate{
ID: aggregateID,
Type: aggregateType,
ResourceOwner: "SYSTEM",
Version: "v1",
}
}
func (s *SetupStep) EditorService() string {
return "system"
}
func (s *SetupStep) EditorUser() string {
return "system"
}
func (s *SetupStep) Type() eventstore.EventType {
return s.typ
}
func (s *SetupStep) Data() interface{} {
return s
}
func (s *SetupStep) UniqueConstraints() []*eventstore.EventUniqueConstraint {
switch s.typ {
case startedType:
return []*eventstore.EventUniqueConstraint{
eventstore.NewAddEventUniqueConstraint("migration_started", s.migration.String(), "Errors.Step.Started.AlreadyExists"),
}
case failedType:
return []*eventstore.EventUniqueConstraint{
eventstore.NewRemoveEventUniqueConstraint("migration_started", s.migration.String()),
}
default:
return []*eventstore.EventUniqueConstraint{
eventstore.NewAddEventUniqueConstraint("migration_done", s.migration.String(), "Errors.Step.Done.AlreadyExists"),
}
}
}

View File

@ -0,0 +1,84 @@
package migration
import (
"context"
"encoding/json"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
)
const (
startedType = eventstore.EventType("system.migration.started")
doneType = eventstore.EventType("system.migration.done")
failedType = eventstore.EventType("system.migration.failed")
aggregateType = eventstore.AggregateType("system")
aggregateID = "SYSTEM"
)
type Migration interface {
String() string
Execute(context.Context) error
}
func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration) (err error) {
if should, err := shouldExec(ctx, es, migration); !should || err != nil {
return err
}
if _, err = es.Push(ctx, setupStartedCmd(migration)); err != nil {
return err
}
err = migration.Execute(ctx)
logging.OnError(err).Error("migration failed")
_, err = es.Push(ctx, setupDoneCmd(migration, err))
return err
}
func shouldExec(ctx context.Context, es *eventstore.Eventstore, migration Migration) (should bool, err error) {
events, err := es.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AddQuery().
AggregateTypes(aggregateType).
AggregateIDs(aggregateID).
EventTypes(startedType, doneType, failedType).
Builder())
if err != nil {
return false, err
}
if len(events) == 0 {
return true, nil
}
if events[len(events)-1].Type() == startedType {
return false, nil
}
for _, e := range events {
step := new(SetupStep)
err = json.Unmarshal(e.DataAsBytes(), step)
if err != nil {
return false, err
}
if step.Name != migration.String() {
continue
}
switch e.Type() {
case startedType, doneType:
//TODO: if started should we wait until done/failed?
return false, nil
case failedType:
//TODO: how to allow retries?
logging.WithFields("migration", migration.String()).Error("failed before")
return false, errors.ThrowInternal(nil, "MIGRA-mjI2E", "migration failed before")
}
}
return true, nil
}

View File

@ -26,6 +26,7 @@ const (
OrgProjectMappingSearchKeyProjectID OrgProjectMappingSearchKeyProjectID
OrgProjectMappingSearchKeyOrgID OrgProjectMappingSearchKeyOrgID
OrgProjectMappingSearchKeyProjectGrantID OrgProjectMappingSearchKeyProjectGrantID
OrgProjectMappingSearchKeyInstanceID
) )
type OrgProjectMappingViewSearchQuery struct { type OrgProjectMappingViewSearchQuery struct {

View File

@ -4,10 +4,12 @@ const (
OrgProjectMappingKeyProjectID = "project_id" OrgProjectMappingKeyProjectID = "project_id"
OrgProjectMappingKeyOrgID = "org_id" OrgProjectMappingKeyOrgID = "org_id"
OrgProjectMappingKeyProjectGrantID = "project_grant_id" OrgProjectMappingKeyProjectGrantID = "project_grant_id"
OrgProjectMappingKeyInstanceID = "instance_id"
) )
type OrgProjectMapping struct { type OrgProjectMapping struct {
ProjectID string `json:"-" gorm:"column:project_id;primary_key"` ProjectID string `json:"-" gorm:"column:project_id;primary_key"`
OrgID string `json:"-" gorm:"column:org_id;primary_key"` OrgID string `json:"-" gorm:"column:org_id;primary_key"`
ProjectGrantID string `json:"-" gorm:"column:project_grant_id;"` ProjectGrantID string `json:"-" gorm:"column:project_grant_id"`
InstanceID string `json:"instanceID" gorm:"column:instance_id"`
} }

View File

@ -57,6 +57,8 @@ func (key OrgProjectMappingSearchKey) ToColumnName() string {
return OrgProjectMappingKeyProjectID return OrgProjectMappingKeyProjectID
case proj_model.OrgProjectMappingSearchKeyProjectGrantID: case proj_model.OrgProjectMappingSearchKeyProjectGrantID:
return OrgProjectMappingKeyProjectGrantID return OrgProjectMappingKeyProjectGrantID
case proj_model.OrgProjectMappingSearchKeyInstanceID:
return OrgProjectMappingKeyInstanceID
default: default:
return "" return ""
} }

View File

@ -2,13 +2,15 @@ package model
import ( import (
"encoding/json" "encoding/json"
"time"
"github.com/caos/logging" "github.com/caos/logging"
"github.com/lib/pq"
caos_errs "github.com/caos/zitadel/internal/errors" caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/project/model" "github.com/caos/zitadel/internal/project/model"
es_model "github.com/caos/zitadel/internal/project/repository/eventsourcing/model" es_model "github.com/caos/zitadel/internal/project/repository/eventsourcing/model"
"github.com/lib/pq"
"time"
) )
const ( const (
@ -39,6 +41,7 @@ type ProjectGrant struct {
GrantID string `json:"grantId"` GrantID string `json:"grantId"`
GrantedOrgID string `json:"grantedOrgId"` GrantedOrgID string `json:"grantedOrgId"`
RoleKeys []string `json:"roleKeys"` RoleKeys []string `json:"roleKeys"`
InstanceID string `json:"instanceID"`
} }
func ProjectGrantFromModel(project *model.ProjectGrantView) *ProjectGrantView { func ProjectGrantFromModel(project *model.ProjectGrantView) *ProjectGrantView {

View File

@ -8,6 +8,8 @@ import (
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query/projection" "github.com/caos/zitadel/internal/query/projection"
@ -33,6 +35,10 @@ var (
name: projection.ActionResourceOwnerCol, name: projection.ActionResourceOwnerCol,
table: actionTable, table: actionTable,
} }
ActionColumnInstanceID = Column{
name: projection.ActionInstanceIDCol,
table: actionTable,
}
ActionColumnSequence = Column{ ActionColumnSequence = Column{
name: projection.ActionSequenceCol, name: projection.ActionSequenceCol,
table: actionTable, table: actionTable,
@ -93,7 +99,11 @@ func (q *ActionSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
func (q *Queries) SearchActions(ctx context.Context, queries *ActionSearchQueries) (actions *Actions, err error) { func (q *Queries) SearchActions(ctx context.Context, queries *ActionSearchQueries) (actions *Actions, err error) {
query, scan := prepareActionsQuery() query, scan := prepareActionsQuery()
stmt, args, err := queries.toQuery(query).ToSql() stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
ActionColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
}).
ToSql()
if err != nil { if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-SDgwg", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-SDgwg", "Errors.Query.InvalidRequest")
} }
@ -116,6 +126,7 @@ func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string) (*
sq.Eq{ sq.Eq{
ActionColumnID.identifier(): id, ActionColumnID.identifier(): id,
ActionColumnResourceOwner.identifier(): orgID, ActionColumnResourceOwner.identifier(): orgID,
ActionColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
}).ToSql() }).ToSql()
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement")

View File

@ -8,6 +8,8 @@ import (
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query/projection" "github.com/caos/zitadel/internal/query/projection"
@ -21,6 +23,14 @@ var (
name: projection.FlowTypeCol, name: projection.FlowTypeCol,
table: flowsTriggersTable, table: flowsTriggersTable,
} }
FlowsTriggersColumnChangeDate = Column{
name: projection.FlowChangeDateCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnSequence = Column{
name: projection.FlowSequenceCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnTriggerType = Column{ FlowsTriggersColumnTriggerType = Column{
name: projection.FlowTriggerTypeCol, name: projection.FlowTriggerTypeCol,
table: flowsTriggersTable, table: flowsTriggersTable,
@ -29,6 +39,10 @@ var (
name: projection.FlowResourceOwnerCol, name: projection.FlowResourceOwnerCol,
table: flowsTriggersTable, table: flowsTriggersTable,
} }
FlowsTriggersColumnInstanceID = Column{
name: projection.FlowInstanceIDCol,
table: flowsTriggersTable,
}
FlowsTriggersColumnTriggerSequence = Column{ FlowsTriggersColumnTriggerSequence = Column{
name: projection.FlowActionTriggerSequenceCol, name: projection.FlowActionTriggerSequenceCol,
table: flowsTriggersTable, table: flowsTriggersTable,
@ -40,10 +54,9 @@ var (
) )
type Flow struct { type Flow struct {
CreationDate time.Time //TODO: add in projection ChangeDate time.Time
ChangeDate time.Time //TODO: add in projection ResourceOwner string
ResourceOwner string //TODO: add in projection Sequence uint64
Sequence uint64 //TODO: add in projection
Type domain.FlowType Type domain.FlowType
TriggerActions map[domain.TriggerType][]*Action TriggerActions map[domain.TriggerType][]*Action
@ -55,6 +68,7 @@ func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID s
sq.Eq{ sq.Eq{
FlowsTriggersColumnFlowType.identifier(): flowType, FlowsTriggersColumnFlowType.identifier(): flowType,
FlowsTriggersColumnResourceOwner.identifier(): orgID, FlowsTriggersColumnResourceOwner.identifier(): orgID,
FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
}).ToSql() }).ToSql()
if err != nil { if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-HBRh3", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-HBRh3", "Errors.Query.InvalidRequest")
@ -74,6 +88,7 @@ func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flow
FlowsTriggersColumnFlowType.identifier(): flowType, FlowsTriggersColumnFlowType.identifier(): flowType,
FlowsTriggersColumnTriggerType.identifier(): triggerType, FlowsTriggersColumnTriggerType.identifier(): triggerType,
FlowsTriggersColumnResourceOwner.identifier(): orgID, FlowsTriggersColumnResourceOwner.identifier(): orgID,
FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
ActionColumnState.identifier(): domain.ActionStateActive, ActionColumnState.identifier(): domain.ActionStateActive,
}, },
).ToSql() ).ToSql()
@ -92,7 +107,8 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string) (
stmt, scan := prepareFlowTypesQuery() stmt, scan := prepareFlowTypesQuery()
query, args, err := stmt.Where( query, args, err := stmt.Where(
sq.Eq{ sq.Eq{
FlowsTriggersColumnActionID.identifier(): actionID, FlowsTriggersColumnActionID.identifier(): actionID,
FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).ID,
}, },
).ToSql() ).ToSql()
if err != nil { if err != nil {
@ -185,6 +201,9 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
FlowsTriggersColumnTriggerType.identifier(), FlowsTriggersColumnTriggerType.identifier(),
FlowsTriggersColumnTriggerSequence.identifier(), FlowsTriggersColumnTriggerSequence.identifier(),
FlowsTriggersColumnFlowType.identifier(), FlowsTriggersColumnFlowType.identifier(),
FlowsTriggersColumnChangeDate.identifier(),
FlowsTriggersColumnSequence.identifier(),
FlowsTriggersColumnResourceOwner.identifier(),
). ).
From(flowsTriggersTable.name). From(flowsTriggersTable.name).
LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID)). LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID)).
@ -194,7 +213,6 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
TriggerActions: make(map[domain.TriggerType][]*Action), TriggerActions: make(map[domain.TriggerType][]*Action),
} }
for rows.Next() { for rows.Next() {
// action := new(Action)
var ( var (
actionID sql.NullString actionID sql.NullString
actionCreationDate pq.NullTime actionCreationDate pq.NullTime
@ -207,7 +225,6 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
triggerType domain.TriggerType triggerType domain.TriggerType
triggerSequence int triggerSequence int
flowType domain.FlowType
) )
err := rows.Scan( err := rows.Scan(
&actionID, &actionID,
@ -220,12 +237,14 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) {
&actionScript, &actionScript,
&triggerType, &triggerType,
&triggerSequence, &triggerSequence,
&flowType, &flow.Type,
&flow.ChangeDate,
&flow.Sequence,
&flow.ResourceOwner,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
flow.Type = flowType
if !actionID.Valid { if !actionID.Valid {
continue continue
} }

Some files were not shown because too many files have changed in this diff Show More