diff --git a/internal/api/authz/authorization.go b/internal/api/authz/authorization.go index 7ae5e838f4..36b17e5cd8 100644 --- a/internal/api/authz/authorization.go +++ b/internal/api/authz/authorization.go @@ -14,33 +14,39 @@ const ( authenticated = "authenticated" ) -func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID string, verifier *TokenVerifier, authConfig Config, requiredAuthOption Option, method string) (_ context.Context, err error) { +func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID string, verifier *TokenVerifier, authConfig Config, requiredAuthOption Option, method string) (ctxSetter func(context.Context) context.Context, err error) { ctx, span := tracing.NewServerInterceptorSpan(ctx) defer func() { span.EndWithError(err) }() - ctx, err = VerifyTokenAndWriteCtxData(ctx, token, orgID, verifier, method) + ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, verifier, method) if err != nil { return nil, err } - var perms []string if requiredAuthOption.Permission == authenticated { - return ctx, nil + return func(parent context.Context) context.Context { + return context.WithValue(parent, dataKey, ctxData) + }, nil } - ctx, perms, err = getUserMethodPermissions(ctx, verifier, requiredAuthOption.Permission, authConfig) + requestedPermissions, allPermissions, err := getUserMethodPermissions(ctx, verifier, requiredAuthOption.Permission, authConfig, ctxData) if err != nil { return nil, err } ctx, userPermissionSpan := tracing.NewNamedSpan(ctx, "checkUserPermissions") - err = checkUserPermissions(req, perms, requiredAuthOption) + err = checkUserPermissions(req, requestedPermissions, requiredAuthOption) userPermissionSpan.EndWithError(err) if err != nil { return nil, err } - return ctx, nil + return func(parent context.Context) context.Context { + parent = context.WithValue(parent, dataKey, ctxData) + parent = context.WithValue(parent, allPermissionsKey, allPermissions) + parent = context.WithValue(parent, requestPermissionsKey, requestedPermissions) + return parent + }, nil } func checkUserPermissions(req interface{}, userPerms []string, authOpt Option) error { diff --git a/internal/api/authz/context.go b/internal/api/authz/context.go index 4880a7373e..374ef0c571 100644 --- a/internal/api/authz/context.go +++ b/internal/api/authz/context.go @@ -36,29 +36,36 @@ type Grant struct { Roles []string } -func VerifyTokenAndWriteCtxData(ctx context.Context, token, orgID string, t *TokenVerifier, method string) (_ context.Context, err error) { +func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID string, t *TokenVerifier, method string) (_ CtxData, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() if orgID != "" { err = t.ExistsOrg(ctx, orgID) if err != nil { - return nil, errors.ThrowPermissionDenied(nil, "AUTH-Bs7Ds", "Organisation doesn't exist") + return CtxData{}, errors.ThrowPermissionDenied(nil, "AUTH-Bs7Ds", "Organisation doesn't exist") } } userID, clientID, agentID, prefLang, err := verifyAccessToken(ctx, token, t, method) if err != nil { - return nil, err + return CtxData{}, err } projectID, origins, err := t.ProjectIDAndOriginsByClientID(ctx, clientID) if err != nil { - return nil, errors.ThrowPermissionDenied(err, "AUTH-GHpw2", "could not read projectid by clientid") + return CtxData{}, errors.ThrowPermissionDenied(err, "AUTH-GHpw2", "could not read projectid by clientid") } if err := checkOrigin(ctx, origins); err != nil { - return nil, err + return CtxData{}, err } - return context.WithValue(ctx, dataKey, CtxData{UserID: userID, OrgID: orgID, ProjectID: projectID, AgentID: agentID, PreferredLanguage: prefLang}), nil + return CtxData{ + UserID: userID, + OrgID: orgID, + ProjectID: projectID, + AgentID: agentID, + PreferredLanguage: prefLang, + }, nil + } func SetCtxData(ctx context.Context, ctxData CtxData) context.Context { diff --git a/internal/api/authz/permissions.go b/internal/api/authz/permissions.go index 8023152ca9..6423f80fd7 100644 --- a/internal/api/authz/permissions.go +++ b/internal/api/authz/permissions.go @@ -7,29 +7,29 @@ import ( "github.com/caos/zitadel/internal/telemetry/tracing" ) -func getUserMethodPermissions(ctx context.Context, t *TokenVerifier, requiredPerm string, authConfig Config) (_ context.Context, _ []string, err error) { +func getUserMethodPermissions(ctx context.Context, t *TokenVerifier, requiredPerm string, authConfig Config, ctxData CtxData) (requestedPermissions, allPermissions []string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - ctxData := GetCtxData(ctx) if ctxData.IsZero() { return nil, nil, errors.ThrowUnauthenticated(nil, "AUTH-rKLWEH", "context missing") } + + ctx = context.WithValue(ctx, dataKey, ctxData) grant, err := t.ResolveGrant(ctx) if err != nil { return nil, nil, err } if grant == nil { - return context.WithValue(ctx, requestPermissionsKey, []string{}), []string{}, nil + return requestedPermissions, nil, nil } - requestPermissions, allPermissions := mapGrantToPermissions(requiredPerm, grant, authConfig) - ctx = context.WithValue(ctx, allPermissionsKey, allPermissions) - return context.WithValue(ctx, requestPermissionsKey, requestPermissions), requestPermissions, nil + requestedPermissions, allPermissions = mapGrantToPermissions(requiredPerm, grant, authConfig) + return requestedPermissions, allPermissions, nil } -func mapGrantToPermissions(requiredPerm string, grant *Grant, authConfig Config) ([]string, []string) { - requestPermissions := make([]string, 0) - allPermissions := make([]string, 0) +func mapGrantToPermissions(requiredPerm string, grant *Grant, authConfig Config) (requestPermissions, allPermissions []string) { + requestPermissions = make([]string, 0) + allPermissions = make([]string, 0) for _, role := range grant.Roles { requestPermissions, allPermissions = mapRoleToPerm(requiredPerm, role, authConfig, requestPermissions, allPermissions) } diff --git a/internal/api/authz/permissions_test.go b/internal/api/authz/permissions_test.go index 92e4407706..69877fd4cb 100644 --- a/internal/api/authz/permissions_test.go +++ b/internal/api/authz/permissions_test.go @@ -49,7 +49,7 @@ func equalStringArray(a, b []string) bool { func Test_GetUserMethodPermissions(t *testing.T) { type args struct { - ctx context.Context + ctxData CtxData verifier *TokenVerifier requiredPerm string authConfig Config @@ -64,7 +64,7 @@ func Test_GetUserMethodPermissions(t *testing.T) { { name: "Empty Context", args: args{ - ctx: getTestCtx("", ""), + ctxData: CtxData{}, verifier: Start(&testVerifier{grant: &Grant{ Roles: []string{"ORG_OWNER"}, }}), @@ -89,7 +89,7 @@ func Test_GetUserMethodPermissions(t *testing.T) { { name: "No Grants", args: args{ - ctx: getTestCtx("", ""), + ctxData: CtxData{}, verifier: Start(&testVerifier{grant: &Grant{}}), requiredPerm: "project.read", authConfig: Config{ @@ -110,9 +110,9 @@ func Test_GetUserMethodPermissions(t *testing.T) { { name: "Get Permissions", args: args{ - ctx: getTestCtx("userID", "orgID"), + ctxData: CtxData{UserID: "userID", OrgID: "orgID"}, verifier: Start(&testVerifier{grant: &Grant{ - Roles: []string{"ORG_OWNER"}, + Roles: []string{"IAM_OWNER"}, }}), requiredPerm: "project.read", authConfig: Config{ @@ -133,7 +133,7 @@ func Test_GetUserMethodPermissions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, perms, err := getUserMethodPermissions(tt.args.ctx, tt.args.verifier, tt.args.requiredPerm, tt.args.authConfig) + _, perms, err := getUserMethodPermissions(context.Background(), tt.args.verifier, tt.args.requiredPerm, tt.args.authConfig, tt.args.ctxData) if tt.wantErr && err == nil { t.Errorf("got wrong result, should get err: actual: %v ", err) diff --git a/internal/api/grpc/server/middleware/auth_interceptor.go b/internal/api/grpc/server/middleware/auth_interceptor.go index 347e0b18f9..165be2a6f9 100644 --- a/internal/api/grpc/server/middleware/auth_interceptor.go +++ b/internal/api/grpc/server/middleware/auth_interceptor.go @@ -25,20 +25,20 @@ func authorize(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, return handler(ctx, req) } - ctx, span := tracing.NewServerInterceptorSpan(ctx) + authCtx, span := tracing.NewServerInterceptorSpan(ctx) defer func() { span.EndWithError(err) }() - authToken := grpc_util.GetAuthorizationHeader(ctx) + authToken := grpc_util.GetAuthorizationHeader(authCtx) if authToken == "" { return nil, status.Error(codes.Unauthenticated, "auth header missing") } - orgID := grpc_util.GetHeader(ctx, http.ZitadelOrgID) + orgID := grpc_util.GetHeader(authCtx, http.ZitadelOrgID) - ctx, err = authz.CheckUserAuthorization(ctx, req, authToken, orgID, verifier, authConfig, authOpt, info.FullMethod) + ctxSetter, err := authz.CheckUserAuthorization(authCtx, req, authToken, orgID, verifier, authConfig, authOpt, info.FullMethod) if err != nil { return nil, err } span.End() - return handler(ctx, req) + return handler(ctxSetter(ctx), req) }