From 278a278a5b2ab61739584b57b23e3927d736dace Mon Sep 17 00:00:00 2001 From: Silvan Date: Fri, 17 Dec 2021 16:28:41 +0100 Subject: [PATCH] fix(authz): retry search memberships if no memberships found (#2869) --- internal/api/authz/context.go | 11 ++-- internal/api/authz/permissions.go | 15 +++++- internal/api/authz/retry.go | 15 ++++++ internal/api/authz/retry_test.go | 88 +++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 9 deletions(-) create mode 100644 internal/api/authz/retry.go create mode 100644 internal/api/authz/retry_test.go diff --git a/internal/api/authz/context.go b/internal/api/authz/context.go index 4015277c2f..c16f4e2c2f 100644 --- a/internal/api/authz/context.go +++ b/internal/api/authz/context.go @@ -2,7 +2,6 @@ package authz import ( "context" - "time" "github.com/caos/zitadel/internal/api/grpc" http_util "github.com/caos/zitadel/internal/api/http" @@ -84,13 +83,9 @@ func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID string, t *To err = t.ExistsOrg(ctx, orgID) if err != nil { - for i := 0; i < 3; i++ { //TODO: workaround if org projection is not yet up-to-date - time.Sleep(500 * time.Millisecond) - err := t.ExistsOrg(ctx, orgID) - if err == nil { - break - } - } + err = retry(func() error { + return t.ExistsOrg(ctx, orgID) + }) if err != nil { return CtxData{}, errors.ThrowPermissionDenied(nil, "AUTH-Bs7Ds", "Organisation doesn't exist") } diff --git a/internal/api/authz/permissions.go b/internal/api/authz/permissions.go index c91eeaef55..64228f68b3 100644 --- a/internal/api/authz/permissions.go +++ b/internal/api/authz/permissions.go @@ -2,6 +2,7 @@ package authz import ( "context" + "github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/telemetry/tracing" ) @@ -20,7 +21,19 @@ func getUserMethodPermissions(ctx context.Context, t *TokenVerifier, requiredPer return nil, nil, err } if len(memberships) == 0 { - return requestedPermissions, nil, nil + err = retry(func() error { + memberships, err = t.SearchMyMemberships(ctx) + if err != nil { + return err + } + if len(memberships) == 0 { + return errors.ThrowNotFound(nil, "AUTHZ-cdgFk", "membership not found") + } + return nil + }) + if err != nil { + return nil, nil, nil + } } requestedPermissions, allPermissions = mapMembershipsToPermissions(requiredPerm, memberships, authConfig) return requestedPermissions, allPermissions, nil diff --git a/internal/api/authz/retry.go b/internal/api/authz/retry.go new file mode 100644 index 0000000000..b22cb542fc --- /dev/null +++ b/internal/api/authz/retry.go @@ -0,0 +1,15 @@ +package authz + +import "time" + +//TODO: workaround if org projection is not yet up-to-date +func retry(retriable func() error) (err error) { + for i := 0; i < 3; i++ { + time.Sleep(500 * time.Millisecond) + err = retriable() + if err == nil { + return nil + } + } + return err +} diff --git a/internal/api/authz/retry_test.go b/internal/api/authz/retry_test.go new file mode 100644 index 0000000000..685509b955 --- /dev/null +++ b/internal/api/authz/retry_test.go @@ -0,0 +1,88 @@ +package authz + +import ( + "errors" + "testing" +) + +func Test_retry(t *testing.T) { + type args struct { + retriable func(*int) func() error + } + type want struct { + executions int + err bool + } + tests := []struct { + name string + args args + want want + }{ + { + name: "1 execution", + args: args{ + retriable: func(execs *int) func() error { + return func() error { + if *execs < 1 { + *execs++ + return errors.New("not 1") + } + return nil + } + }, + }, + want: want{ + err: false, + executions: 1, + }, + }, + { + name: "2 execution", + args: args{ + retriable: func(execs *int) func() error { + return func() error { + if *execs < 2 { + *execs++ + return errors.New("not 2") + } + return nil + } + }, + }, + want: want{ + err: false, + executions: 2, + }, + }, + { + name: "too many execution", + args: args{ + retriable: func(execs *int) func() error { + return func() error { + if *execs < 3 { + *execs++ + return errors.New("not 3") + } + return nil + } + }, + }, + want: want{ + err: true, + executions: 3, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var execs int + + if err := retry(tt.args.retriable(&execs)); (err != nil) != tt.want.err { + t.Errorf("retry() error = %v, want.err %v", err, tt.want.err) + } + if execs != tt.want.executions { + t.Errorf("retry() executions: want: %d got: %d", tt.want.executions, execs) + } + }) + } +}