mirror of
				https://github.com/zitadel/zitadel.git
				synced 2025-10-25 20:38:48 +00:00 
			
		
		
		
	userinfo and project roles in go routines
This commit is contained in:
		| @@ -370,7 +370,7 @@ func (o *OPStorage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, us | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	o.setUserInfoRoleClaims(userInfo, projectRoles) | ||||
| 	setUserInfoRoleClaims(userInfo, projectRoles) | ||||
|  | ||||
| 	return o.userinfoFlows(ctx, user, userGrants, userInfo) | ||||
| } | ||||
| @@ -432,7 +432,7 @@ func (o *OPStorage) setUserInfoResourceOwner(ctx context.Context, userInfo *oidc | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (o *OPStorage) setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) { | ||||
| func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) { | ||||
| 	if roles != nil && len(roles.projects) > 0 { | ||||
| 		if roles, ok := roles.projects[roles.requestProjectID]; ok { | ||||
| 			userInfo.AppendClaims(ClaimProjectRoles, roles) | ||||
|   | ||||
| @@ -82,7 +82,7 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR | ||||
| 	if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	userInfo, err := s.storage.query.GetOIDCUserinfo(ctx, token.userID, token.scope, []string{client.projectID}) | ||||
| 	userInfo, err := s.getUserInfoWithRoles(ctx, token.userID, client.projectID, token.scope, []string{client.projectID}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -98,7 +98,7 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR | ||||
| 		Issuer:     op.IssuerFromContext(ctx), | ||||
| 		JWTID:      token.tokenID, | ||||
| 	} | ||||
| 	introspectionResp.SetUserInfo(userinfoToOIDC(userInfo, token.scope)) | ||||
| 	introspectionResp.SetUserInfo(userInfo) | ||||
| 	return op.NewResponse(introspectionResp), nil | ||||
| } | ||||
|  | ||||
| @@ -224,7 +224,7 @@ func introspectionTokenResultV1(tokenID, subject string, token *model.TokenView) | ||||
| 		tokenID:         tokenID, | ||||
| 		userID:          token.UserID, | ||||
| 		subject:         subject, | ||||
| 		clientID:        token.ApplicationID, // check correctness? | ||||
| 		clientID:        token.ApplicationID, | ||||
| 		audience:        token.Audience, | ||||
| 		scope:           token.Scopes, | ||||
| 		tokenCreation:   token.CreationDate, | ||||
|   | ||||
| @@ -122,8 +122,9 @@ func NewServer( | ||||
| 	server := &Server{ | ||||
| 		storage:        storage, | ||||
| 		LegacyServer:   op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), | ||||
| 		hashAlg:        crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant. | ||||
| 		query:          query, | ||||
| 		fallbackLogger: fallbackLogger, | ||||
| 		hashAlg:        crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant. | ||||
| 	} | ||||
| 	metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} | ||||
| 	server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware( | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"github.com/zitadel/oidc/v3/pkg/oidc" | ||||
| 	"github.com/zitadel/oidc/v3/pkg/op" | ||||
| 	"github.com/zitadel/zitadel/internal/crypto" | ||||
| 	"github.com/zitadel/zitadel/internal/query" | ||||
| 	"github.com/zitadel/zitadel/internal/telemetry/tracing" | ||||
| ) | ||||
|  | ||||
| @@ -18,6 +19,8 @@ type Server struct { | ||||
| 	storage *OPStorage | ||||
| 	*op.LegacyServer | ||||
|  | ||||
| 	query *query.Queries | ||||
|  | ||||
| 	fallbackLogger *slog.Logger | ||||
| 	hashAlg        crypto.HashAlgorithm | ||||
| } | ||||
|   | ||||
| @@ -1,17 +1,152 @@ | ||||
| package oidc | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"slices" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/zitadel/oidc/v3/pkg/oidc" | ||||
|  | ||||
| 	"github.com/zitadel/zitadel/internal/domain" | ||||
| 	"github.com/zitadel/zitadel/internal/query" | ||||
| 	"github.com/zitadel/zitadel/internal/telemetry/tracing" | ||||
| ) | ||||
|  | ||||
| func userinfoToOIDC(user *query.OIDCUserinfo, scopes []string) *oidc.UserInfo { | ||||
| func (s *Server) getUserInfoWithRoles(ctx context.Context, userID, projectID string, scope, roleAudience []string) (_ *oidc.UserInfo, err error) { | ||||
| 	ctx, span := tracing.NewSpan(ctx) | ||||
| 	defer func() { span.EndWithError(err) }() | ||||
| 	ctx, cancel := context.WithCancel(ctx) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	userInfoChan := make(chan *userInfoResult) | ||||
| 	go s.getUserInfo(ctx, userID, scope, roleAudience, userInfoChan) | ||||
|  | ||||
| 	rolesChan := make(chan *assertRolesResult) | ||||
| 	go s.assertRoles(ctx, userID, projectID, scope, roleAudience, rolesChan) | ||||
|  | ||||
| 	var ( | ||||
| 		userInfoResult    *userInfoResult | ||||
| 		assertRolesResult *assertRolesResult | ||||
| 	) | ||||
|  | ||||
| 	// make sure both channels are always read, | ||||
| 	// and cancel the context on first error | ||||
| 	for i := 0; i < 2; i++ { | ||||
| 		var resErr error | ||||
|  | ||||
| 		select { | ||||
| 		case userInfoResult = <-userInfoChan: | ||||
| 			resErr = userInfoResult.err | ||||
| 		case assertRolesResult = <-rolesChan: | ||||
| 			resErr = assertRolesResult.err | ||||
| 		} | ||||
|  | ||||
| 		if resErr == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		cancel() | ||||
|  | ||||
| 		// we only care for the first error that occured, | ||||
| 		// as the next error is most probably a context error. | ||||
| 		if err == nil { | ||||
| 			err = resErr | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	userInfo := userInfoToOIDC(userInfoResult.userInfo, scope) | ||||
| 	setUserInfoRoleClaims(userInfo, assertRolesResult.projectsRoles) | ||||
|  | ||||
| 	return userInfo, nil | ||||
| } | ||||
|  | ||||
| type userInfoResult struct { | ||||
| 	userInfo *query.OIDCUserInfo | ||||
| 	err      error | ||||
| } | ||||
|  | ||||
| func (s *Server) getUserInfo(ctx context.Context, userID string, scope, roleAudience []string, rc chan<- *userInfoResult) { | ||||
| 	userInfo, err := s.storage.query.GetOIDCUserInfo(ctx, userID, scope, roleAudience) | ||||
| 	rc <- &userInfoResult{ | ||||
| 		userInfo: userInfo, | ||||
| 		err:      err, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type assertRolesResult struct { | ||||
| 	userGrants    *query.UserGrants | ||||
| 	projectsRoles *projectsRoles | ||||
| 	err           error | ||||
| } | ||||
|  | ||||
| func (s *Server) assertRoles(ctx context.Context, userID, projectID string, scope, roleAudience []string, rc chan<- *assertRolesResult) { | ||||
| 	userGrands, projectsRoles, err := func() (*query.UserGrants, *projectsRoles, error) { | ||||
| 		// if all roles are requested take the audience for those from the scopes | ||||
| 		if slices.Contains(scope, domain.ScopeProjectsRoles) { | ||||
| 			roleAudience = domain.AddAudScopeToAudience(ctx, roleAudience, scope) | ||||
| 		} | ||||
|  | ||||
| 		requestedRoles := make([]string, 0, len(scope)) | ||||
| 		for _, s := range scope { | ||||
| 			if role, ok := strings.CutPrefix(s, ScopeProjectRolePrefix); ok { | ||||
| 				requestedRoles = append(requestedRoles, role) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if len(requestedRoles) == 0 && len(roleAudience) == 0 { | ||||
| 			return nil, nil, nil | ||||
| 		} | ||||
|  | ||||
| 		// ensure the projectID of the requesting is part of the roleAudience | ||||
| 		if projectID != "" { | ||||
| 			roleAudience = append(roleAudience, projectID) | ||||
| 		} | ||||
| 		queries := make([]query.SearchQuery, 0, 2) | ||||
| 		projectQuery, err := query.NewUserGrantProjectIDsSearchQuery(roleAudience) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 		} | ||||
| 		queries = append(queries, projectQuery) | ||||
| 		userIDQuery, err := query.NewUserGrantUserIDSearchQuery(userID) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 		} | ||||
| 		queries = append(queries, userIDQuery) | ||||
| 		grants, err := s.query.UserGrants(ctx, &query.UserGrantsQueries{ | ||||
| 			Queries: queries, | ||||
| 		}, false, false) // triggers disabled | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 		} | ||||
| 		roles := new(projectsRoles) | ||||
| 		// if specific roles where requested, check if they are granted and append them in the roles list | ||||
| 		if len(requestedRoles) > 0 { | ||||
| 			for _, requestedRole := range requestedRoles { | ||||
| 				for _, grant := range grants.UserGrants { | ||||
| 					checkGrantedRoles(roles, grant, requestedRole, grant.ProjectID == projectID) | ||||
| 				} | ||||
| 			} | ||||
| 			return grants, roles, nil | ||||
| 		} | ||||
| 		// no specific roles were requested, so convert any grants into roles | ||||
| 		for _, grant := range grants.UserGrants { | ||||
| 			for _, role := range grant.Roles { | ||||
| 				roles.Add(grant.ProjectID, role, grant.ResourceOwner, grant.OrgPrimaryDomain, grant.ProjectID == projectID) | ||||
| 			} | ||||
| 		} | ||||
| 		return grants, roles, nil | ||||
| 	}() | ||||
|  | ||||
| 	rc <- &assertRolesResult{ | ||||
| 		userGrants:    userGrands, | ||||
| 		projectsRoles: projectsRoles, | ||||
| 		err:           err, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func userInfoToOIDC(user *query.OIDCUserInfo, scope []string) *oidc.UserInfo { | ||||
| 	out := new(oidc.UserInfo) | ||||
| 	for _, scope := range scopes { | ||||
| 		switch scope { | ||||
| 	for _, s := range scope { | ||||
| 		switch s { | ||||
| 		case oidc.ScopeOpenID: | ||||
| 			out.Subject = user.ID | ||||
| 		case oidc.ScopeEmail: | ||||
| @@ -29,11 +164,11 @@ func userinfoToOIDC(user *query.OIDCUserinfo, scopes []string) *oidc.UserInfo { | ||||
| 		case ScopeResourceOwner: | ||||
| 			setUserInfoOrgClaims(user, out) | ||||
| 		default: | ||||
| 			if strings.HasPrefix(scope, domain.OrgDomainPrimaryScope) { | ||||
| 				out.AppendClaims(domain.OrgDomainPrimaryClaim, strings.TrimPrefix(scope, domain.OrgDomainPrimaryScope)) | ||||
| 			if strings.HasPrefix(s, domain.OrgDomainPrimaryScope) { | ||||
| 				out.AppendClaims(domain.OrgDomainPrimaryClaim, strings.TrimPrefix(s, domain.OrgDomainPrimaryScope)) | ||||
| 			} | ||||
| 			if strings.HasPrefix(scope, domain.OrgIDScope) { | ||||
| 				out.AppendClaims(domain.OrgIDClaim, strings.TrimPrefix(scope, domain.OrgIDScope)) | ||||
| 			if strings.HasPrefix(s, domain.OrgIDScope) { | ||||
| 				out.AppendClaims(domain.OrgIDClaim, strings.TrimPrefix(s, domain.OrgIDScope)) | ||||
| 				setUserInfoOrgClaims(user, out) | ||||
| 			} | ||||
| 		} | ||||
| @@ -42,14 +177,14 @@ func userinfoToOIDC(user *query.OIDCUserinfo, scopes []string) *oidc.UserInfo { | ||||
| 	return out | ||||
| } | ||||
|  | ||||
| func userInfoEmailToOIDC(user *query.OIDCUserinfo) oidc.UserInfoEmail { | ||||
| func userInfoEmailToOIDC(user *query.OIDCUserInfo) oidc.UserInfoEmail { | ||||
| 	return oidc.UserInfoEmail{ | ||||
| 		Email:         string(user.Email), | ||||
| 		EmailVerified: oidc.Bool(user.IsEmailVerified), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func userInfoProfileToOidc(user *query.OIDCUserinfo) oidc.UserInfoProfile { | ||||
| func userInfoProfileToOidc(user *query.OIDCUserInfo) oidc.UserInfoProfile { | ||||
| 	return oidc.UserInfoProfile{ | ||||
| 		Name:       user.Name, | ||||
| 		GivenName:  user.FirstName, | ||||
| @@ -63,14 +198,14 @@ func userInfoProfileToOidc(user *query.OIDCUserinfo) oidc.UserInfoProfile { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func userInfoPhoneToOIDC(user *query.OIDCUserinfo) oidc.UserInfoPhone { | ||||
| func userInfoPhoneToOIDC(user *query.OIDCUserInfo) oidc.UserInfoPhone { | ||||
| 	return oidc.UserInfoPhone{ | ||||
| 		PhoneNumber:         string(user.Phone), | ||||
| 		PhoneNumberVerified: user.IsPhoneVerified, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func userInfoAddressToOIDC(user *query.OIDCUserinfo) *oidc.UserInfoAddress { | ||||
| func userInfoAddressToOIDC(user *query.OIDCUserInfo) *oidc.UserInfoAddress { | ||||
| 	return &oidc.UserInfoAddress{ | ||||
| 		// Formatted: ??, | ||||
| 		StreetAddress: user.StreetAddress, | ||||
| @@ -81,7 +216,7 @@ func userInfoAddressToOIDC(user *query.OIDCUserinfo) *oidc.UserInfoAddress { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func setUserInfoOrgClaims(user *query.OIDCUserinfo, out *oidc.UserInfo) { | ||||
| func setUserInfoOrgClaims(user *query.OIDCUserInfo, out *oidc.UserInfo) { | ||||
| 	out.AppendClaims(ClaimResourceOwner+"id", user.OrgID) | ||||
| 	out.AppendClaims(ClaimResourceOwner+"name", user.OrgName) | ||||
| 	out.AppendClaims(ClaimResourceOwner+"primary_domain", user.OrgPrimaryDomain) | ||||
|   | ||||
| @@ -15,7 +15,7 @@ import ( | ||||
| 	"golang.org/x/text/language" | ||||
| ) | ||||
|  | ||||
| func (q *Queries) GetOIDCUserinfo(ctx context.Context, userID string, scope, roleAudience []string) (_ *OIDCUserinfo, err error) { | ||||
| func (q *Queries) GetOIDCUserInfo(ctx context.Context, userID string, scope, roleAudience []string) (_ *OIDCUserInfo, err error) { | ||||
| 	if slices.Contains(scope, domain.ScopeProjectsRoles) { | ||||
| 		roleAudience = domain.AddAudScopeToAudience(ctx, roleAudience, scope) | ||||
| 		// TODO: we need to get the project roles and user roles. | ||||
| @@ -37,7 +37,7 @@ func (q *Queries) GetOIDCUserinfo(ctx context.Context, userID string, scope, rol | ||||
| 		user.OrgPrimaryDomain = org.PrimaryDomain | ||||
| 	} | ||||
|  | ||||
| 	return &user.OIDCUserinfo, nil | ||||
| 	return &user.OIDCUserInfo, nil | ||||
| } | ||||
|  | ||||
| func hasOrgScope(scope []string) bool { | ||||
| @@ -46,7 +46,7 @@ func hasOrgScope(scope []string) bool { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| type OIDCUserinfo struct { | ||||
| type OIDCUserInfo struct { | ||||
| 	ID                string | ||||
| 	UserName          string | ||||
| 	Name              string | ||||
| @@ -80,25 +80,25 @@ type OIDCUserinfo struct { | ||||
| 	Metadata map[string]string | ||||
| } | ||||
|  | ||||
| type oidcUserinfoReadmodel struct { | ||||
| type oidcUserInfoReadmodel struct { | ||||
| 	eventstore.ReadModel | ||||
| 	scope []string // Scope is used to determine events | ||||
| 	OIDCUserinfo | ||||
| 	OIDCUserInfo | ||||
| } | ||||
|  | ||||
| func newOidcUserinfoReadModel(userID string, scope []string) *oidcUserinfoReadmodel { | ||||
| 	return &oidcUserinfoReadmodel{ | ||||
| func newOidcUserinfoReadModel(userID string, scope []string) *oidcUserInfoReadmodel { | ||||
| 	return &oidcUserInfoReadmodel{ | ||||
| 		ReadModel: eventstore.ReadModel{ | ||||
| 			AggregateID: userID, | ||||
| 		}, | ||||
| 		scope: scope, | ||||
| 		OIDCUserinfo: OIDCUserinfo{ | ||||
| 		OIDCUserInfo: OIDCUserInfo{ | ||||
| 			ID: userID, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (rm *oidcUserinfoReadmodel) Query() *eventstore.SearchQueryBuilder { | ||||
| func (rm *oidcUserInfoReadmodel) Query() *eventstore.SearchQueryBuilder { | ||||
| 	return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). | ||||
| 		AwaitOpenTransactions(). | ||||
| 		AllowTimeTravel(). | ||||
| @@ -112,7 +112,7 @@ func (rm *oidcUserinfoReadmodel) Query() *eventstore.SearchQueryBuilder { | ||||
| // scopeToEventTypes sets required user events to obtain get the correct userinfo. | ||||
| // Events such as UserLocked, UserDeactivated and UserRemoved are not checked, | ||||
| // as access tokens should already be revoked. | ||||
| func (rm *oidcUserinfoReadmodel) scopeToEventTypes() []eventstore.EventType { | ||||
| func (rm *oidcUserInfoReadmodel) scopeToEventTypes() []eventstore.EventType { | ||||
| 	types := make([]eventstore.EventType, 0, len(rm.scope)) | ||||
| 	types = append(types, user.HumanAddedType, user.MachineAddedEventType) | ||||
|  | ||||
| @@ -133,7 +133,7 @@ func (rm *oidcUserinfoReadmodel) scopeToEventTypes() []eventstore.EventType { | ||||
| 	return slices.Compact(types) | ||||
| } | ||||
|  | ||||
| func (rm *oidcUserinfoReadmodel) Reduce() error { | ||||
| func (rm *oidcUserInfoReadmodel) Reduce() error { | ||||
| 	for _, event := range rm.Events { | ||||
| 		switch e := event.(type) { | ||||
| 		case *user.HumanAddedEvent: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Tim Möhlmann
					Tim Möhlmann