Merge branch 'main' into grcp-server-reflect

This commit is contained in:
Tim Möhlmann
2023-05-07 16:47:52 +02:00
107 changed files with 5654 additions and 173 deletions

View File

@@ -17,6 +17,7 @@ import (
internal_authz "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/server"
http_util "github.com/zitadel/zitadel/internal/api/http"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
@@ -169,6 +170,7 @@ func (a *API) routeGRPCWeb() {
return true
}),
)
a.router.Use(http_mw.RobotsTagHandler)
a.router.NewRoute().
Methods(http.MethodPost, http.MethodOptions).
MatcherFunc(

View File

@@ -14,11 +14,16 @@ const (
authenticated = "authenticated"
)
func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgIDHeader string, verifier *TokenVerifier, authConfig Config, requiredAuthOption Option, method string) (ctxSetter func(context.Context) context.Context, err error) {
// CheckUserAuthorization verifies that:
// - the token is active,
// - the organisation (**either** provided by ID or verified domain) exists
// - the user is permitted to call the requested endpoint (permission option in proto)
// it will pass the [CtxData] and permission of the user into the ctx [context.Context]
func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID, orgDomain string, verifier *TokenVerifier, authConfig Config, requiredAuthOption Option, method string) (ctxSetter func(context.Context) context.Context, err error) {
ctx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgIDHeader, verifier, method)
ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, orgDomain, verifier, method)
if err != nil {
return nil, err
}

View File

@@ -60,7 +60,7 @@ const (
MemberTypeIam
)
func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID string, t *TokenVerifier, method string) (_ CtxData, err error) {
func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain string, t *TokenVerifier, method string) (_ CtxData, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@@ -82,14 +82,15 @@ func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID string, t *To
if err := checkOrigin(ctx, origins); err != nil {
return CtxData{}, err
}
if orgID == "" {
if orgID == "" && orgDomain == "" {
orgID = resourceOwner
}
err = t.ExistsOrg(ctx, orgID)
verifiedOrgID, err := t.ExistsOrg(ctx, orgID, orgDomain)
if err != nil {
err = retry(func() error {
return t.ExistsOrg(ctx, orgID)
verifiedOrgID, err = t.ExistsOrg(ctx, orgID, orgDomain)
return err
})
if err != nil {
return CtxData{}, errors.ThrowPermissionDenied(nil, "AUTH-Bs7Ds", "Organisation doesn't exist")
@@ -98,7 +99,7 @@ func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID string, t *To
return CtxData{
UserID: userID,
OrgID: orgID,
OrgID: verifiedOrgID,
ProjectID: projectID,
AgentID: agentID,
PreferredLanguage: prefLang,

View File

@@ -7,12 +7,8 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func CheckPermission(ctx context.Context, resolver MembershipsResolver, roleMappings []RoleMapping, permission, orgID, resourceID string, allowSelf bool) (err error) {
ctxData := GetCtxData(ctx)
if allowSelf && ctxData.UserID == resourceID {
return nil
}
requestedPermissions, _, err := getUserPermissions(ctx, resolver, permission, roleMappings, ctxData, orgID)
func CheckPermission(ctx context.Context, resolver MembershipsResolver, roleMappings []RoleMapping, permission, orgID, resourceID string) (err error) {
requestedPermissions, _, err := getUserPermissions(ctx, resolver, permission, roleMappings, GetCtxData(ctx), orgID)
if err != nil {
return err
}

View File

@@ -26,8 +26,8 @@ func (v *testVerifier) ProjectIDAndOriginsByClientID(ctx context.Context, client
return "", nil, nil
}
func (v *testVerifier) ExistsOrg(ctx context.Context, orgID string) error {
return nil
func (v *testVerifier) ExistsOrg(ctx context.Context, orgID, domain string) (string, error) {
return orgID, nil
}
func (v *testVerifier) VerifierClientID(ctx context.Context, appName string) (string, string, error) {

View File

@@ -3,6 +3,8 @@ package authz
import (
"context"
"crypto/rsa"
"encoding/base64"
"fmt"
"os"
"strings"
"sync"
@@ -17,7 +19,8 @@ import (
)
const (
BearerPrefix = "Bearer "
BearerPrefix = "Bearer "
SessionTokenFormat = "sess_%s:%s"
)
type TokenVerifier struct {
@@ -36,7 +39,7 @@ type authZRepo interface {
VerifierClientID(ctx context.Context, name string) (clientID, projectID string, err error)
SearchMyMemberships(ctx context.Context, orgID string) ([]*Membership, error)
ProjectIDAndOriginsByClientID(ctx context.Context, clientID string) (projectID string, origins []string, err error)
ExistsOrg(ctx context.Context, orgID string) error
ExistsOrg(ctx context.Context, id, domain string) (string, error)
}
func Start(authZRepo authZRepo, issuer string, keys map[string]*SystemAPIUser) (v *TokenVerifier) {
@@ -144,10 +147,10 @@ func (v *TokenVerifier) ProjectIDAndOriginsByClientID(ctx context.Context, clien
return v.authZRepo.ProjectIDAndOriginsByClientID(ctx, clientID)
}
func (v *TokenVerifier) ExistsOrg(ctx context.Context, orgID string) (err error) {
func (v *TokenVerifier) ExistsOrg(ctx context.Context, id, domain string) (orgID string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return v.authZRepo.ExistsOrg(ctx, orgID)
return v.authZRepo.ExistsOrg(ctx, id, domain)
}
func (v *TokenVerifier) CheckAuthMethod(method string) (Option, bool) {
@@ -165,3 +168,20 @@ func verifyAccessToken(ctx context.Context, token string, t *TokenVerifier, meth
}
return t.VerifyAccessToken(ctx, parts[1], method)
}
func SessionTokenVerifier(algorithm crypto.EncryptionAlgorithm) func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
decodedToken, err := base64.RawURLEncoding.DecodeString(sessionToken)
if err != nil {
return err
}
_, spanPasswordComparison := tracing.NewNamedSpan(ctx, "crypto.CompareHash")
var token string
token, err = algorithm.DecryptString(decodedToken, algorithm.EncryptionKeyID())
spanPasswordComparison.EndWithError(err)
if err != nil || token != fmt.Sprintf(SessionTokenFormat, sessionID, tokenID) {
return caos_errs.ThrowPermissionDenied(err, "COMMAND-sGr42", "Errors.Session.Token.Invalid")
}
return nil
}
}

View File

@@ -0,0 +1,36 @@
//go:build integration
package admin_test
import (
"context"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/admin"
)
var (
Tester *integration.Tester
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, _, cancel := integration.Contexts(time.Minute)
defer cancel()
Tester = integration.NewTester(ctx)
defer Tester.Done()
return m.Run()
}())
}
func TestServer_Healthz(t *testing.T) {
client := admin.NewAdminServiceClient(Tester.GRPCClientConn)
_, err := client.Healthz(context.TODO(), &admin.HealthzRequest{})
require.NoError(t, err)
}

View File

@@ -4,6 +4,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
)
@@ -17,3 +18,21 @@ func DomainToDetailsPb(objectDetail *domain.ObjectDetails) *object.Details {
}
return details
}
func ToListDetails(response query.SearchResponse) *object.ListDetails {
details := &object.ListDetails{
TotalResult: response.Count,
ProcessedSequence: response.Sequence,
}
if !response.Timestamp.IsZero() {
details.Timestamp = timestamppb.New(response.Timestamp)
}
return details
}
func ListQueryToQuery(query *object.ListQuery) (offset, limit uint64, asc bool) {
if query == nil {
return 0, 0, false
}
return query.Offset, uint64(query.Limit), query.Asc
}

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
@@ -38,6 +39,7 @@ var (
runtime.WithMarshalerOption(runtime.MIMEWildcard, jsonMarshaler),
runtime.WithIncomingHeaderMatcher(headerMatcher),
runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher),
runtime.WithForwardResponseOption(responseForwarder),
}
headerMatcher = runtime.HeaderMatcherFunc(
@@ -50,6 +52,15 @@ var (
return runtime.DefaultHeaderMatcher(header)
},
)
responseForwarder = func(ctx context.Context, w http.ResponseWriter, resp proto.Message) error {
t, ok := resp.(CustomHTTPResponse)
if ok {
// TODO: find a way to return a location header if needed w.Header().Set("location", t.Location())
w.WriteHeader(t.CustomHTTPCode())
}
return nil
}
)
type Gateway struct {
@@ -62,6 +73,10 @@ func (g *Gateway) Handler() http.Handler {
return addInterceptors(g.mux, g.http1HostName)
}
type CustomHTTPResponse interface {
CustomHTTPCode() int
}
type RegisterGatewayFunc func(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error
func CreateGatewayWithPrefix(ctx context.Context, g WithGatewayPrefix, port uint16, http1HostName string) (http.Handler, string, error) {
@@ -134,6 +149,7 @@ func addInterceptors(handler http.Handler, http1HostName string) http.Handler {
handler = http_mw.CallDurationHandler(handler)
handler = http1Host(handler, http1HostName)
handler = http_mw.CORSInterceptor(handler)
handler = http_mw.RobotsTagHandler(handler)
handler = http_mw.DefaultTelemetryHandler(handler)
return http_mw.DefaultMetricsHandler(handler)
}

View File

@@ -11,6 +11,7 @@ import (
grpc_util "github.com/zitadel/zitadel/internal/api/grpc"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
)
func AuthorizationInterceptor(verifier *authz.TokenVerifier, authConfig authz.Config) grpc.UnaryServerInterceptor {
@@ -33,12 +34,14 @@ func authorize(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
return nil, status.Error(codes.Unauthenticated, "auth header missing")
}
var orgDomain string
orgID := grpc_util.GetHeader(authCtx, http.ZitadelOrgID)
if o, ok := req.(AuthContext); ok {
orgID = o.AuthContext()
if o, ok := req.(OrganisationFromRequest); ok {
orgID = o.OrganisationFromRequest().GetOrgId()
orgDomain = o.OrganisationFromRequest().GetOrgDomain()
}
ctxSetter, err := authz.CheckUserAuthorization(authCtx, req, authToken, orgID, verifier, authConfig, authOpt, info.FullMethod)
ctxSetter, err := authz.CheckUserAuthorization(authCtx, req, authToken, orgID, orgDomain, verifier, authConfig, authOpt, info.FullMethod)
if err != nil {
return nil, err
}
@@ -46,6 +49,6 @@ func authorize(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
return handler(ctxSetter(ctx), req)
}
type AuthContext interface {
AuthContext() string
type OrganisationFromRequest interface {
OrganisationFromRequest() *object.Organisation
}

View File

@@ -31,8 +31,8 @@ func (v *verifierMock) SearchMyMemberships(ctx context.Context, orgID string) ([
func (v *verifierMock) ProjectIDAndOriginsByClientID(ctx context.Context, clientID string) (string, []string, error) {
return "", nil, nil
}
func (v *verifierMock) ExistsOrg(ctx context.Context, orgID string) error {
return nil
func (v *verifierMock) ExistsOrg(ctx context.Context, orgID, domain string) (string, error) {
return orgID, nil
}
func (v *verifierMock) VerifierClientID(ctx context.Context, appName string) (string, string, error) {
return "", "", nil

View File

@@ -6,16 +6,18 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/server"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
)
var _ session.SessionServiceServer = (*Server)(nil)
type Server struct {
session.UnimplementedSessionServiceServer
command *command.Commands
query *query.Queries
command *command.Commands
query *query.Queries
checkPermission domain.PermissionCheck
}
type Config struct{}
@@ -23,10 +25,12 @@ type Config struct{}
func CreateServer(
command *command.Commands,
query *query.Queries,
checkPermission domain.PermissionCheck,
) *Server {
return &Server{
command: command,
query: query,
command: command,
query: query,
checkPermission: checkPermission,
}
}

View File

@@ -3,16 +3,260 @@ package session
import (
"context"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
"github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/command"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
)
func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) {
res, err := s.query.SessionByID(ctx, req.GetSessionId(), req.GetSessionToken())
if err != nil {
return nil, err
}
return &session.GetSessionResponse{
Session: &session.Session{
Id: req.Id,
User: &user.User{Id: authz.GetCtxData(ctx).UserID},
},
Session: sessionToPb(res),
}, nil
}
func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) {
queries, err := listSessionsRequestToQuery(ctx, req)
if err != nil {
return nil, err
}
sessions, err := s.query.SearchSessions(ctx, queries)
if err != nil {
return nil, err
}
return &session.ListSessionsResponse{
Details: object.ToListDetails(sessions.SearchResponse),
Sessions: sessionsToPb(sessions.Sessions),
}, nil
}
func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) {
checks, metadata, err := s.createSessionRequestToCommand(ctx, req)
if err != nil {
return nil, err
}
set, err := s.command.CreateSession(ctx, checks, metadata)
if err != nil {
return nil, err
}
return &session.CreateSessionResponse{
Details: object.DomainToDetailsPb(set.ObjectDetails),
SessionId: set.ID,
SessionToken: set.NewToken,
}, nil
}
func (s *Server) SetSession(ctx context.Context, req *session.SetSessionRequest) (*session.SetSessionResponse, error) {
checks, err := s.setSessionRequestToCommand(ctx, req)
if err != nil {
return nil, err
}
set, err := s.command.UpdateSession(ctx, req.GetSessionId(), req.GetSessionToken(), checks, req.GetMetadata())
if err != nil {
return nil, err
}
// if there's no new token, just return the current
if set.NewToken == "" {
set.NewToken = req.GetSessionToken()
}
return &session.SetSessionResponse{
Details: object.DomainToDetailsPb(set.ObjectDetails),
SessionToken: set.NewToken,
}, nil
}
func (s *Server) DeleteSession(ctx context.Context, req *session.DeleteSessionRequest) (*session.DeleteSessionResponse, error) {
details, err := s.command.TerminateSession(ctx, req.GetSessionId(), req.GetSessionToken())
if err != nil {
return nil, err
}
return &session.DeleteSessionResponse{
Details: object.DomainToDetailsPb(details),
}, nil
}
func sessionsToPb(sessions []*query.Session) []*session.Session {
s := make([]*session.Session, len(sessions))
for i, session := range sessions {
s[i] = sessionToPb(session)
}
return s
}
func sessionToPb(s *query.Session) *session.Session {
return &session.Session{
Id: s.ID,
CreationDate: timestamppb.New(s.CreationDate),
ChangeDate: timestamppb.New(s.ChangeDate),
Sequence: s.Sequence,
Factors: factorsToPb(s),
Metadata: s.Metadata,
}
}
func factorsToPb(s *query.Session) *session.Factors {
user := userFactorToPb(s.UserFactor)
pw := passwordFactorToPb(s.PasswordFactor)
if user == nil && pw == nil {
return nil
}
return &session.Factors{
User: user,
Password: pw,
}
}
func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor {
if factor.PasswordCheckedAt.IsZero() {
return nil
}
return &session.PasswordFactor{
VerifiedAt: timestamppb.New(factor.PasswordCheckedAt),
}
}
func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor {
if factor.UserID == "" || factor.UserCheckedAt.IsZero() {
return nil
}
return &session.UserFactor{
VerifiedAt: timestamppb.New(factor.UserCheckedAt),
Id: factor.UserID,
LoginName: factor.LoginName,
DisplayName: factor.DisplayName,
}
}
func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) {
offset, limit, asc := object.ListQueryToQuery(req.Query)
queries, err := sessionQueriesToQuery(ctx, req.GetQueries())
if err != nil {
return nil, err
}
return &query.SessionsSearchQueries{
SearchRequest: query.SearchRequest{
Offset: offset,
Limit: limit,
Asc: asc,
},
Queries: queries,
}, nil
}
func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) {
q := make([]query.SearchQuery, len(queries)+1)
for i, query := range queries {
q[i], err = sessionQueryToQuery(query)
if err != nil {
return nil, err
}
}
creatorQuery, err := query.NewSessionCreatorSearchQuery(authz.GetCtxData(ctx).UserID)
if err != nil {
return nil, err
}
q[len(queries)] = creatorQuery
return q, nil
}
func sessionQueryToQuery(query *session.SearchQuery) (query.SearchQuery, error) {
switch q := query.Query.(type) {
case *session.SearchQuery_IdsQuery:
return idsQueryToQuery(q.IdsQuery)
default:
return nil, caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid")
}
}
func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) {
return query.NewSessionIDsSearchQuery(q.Ids)
}
func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCheck, map[string][]byte, error) {
checks, err := s.checksToCommand(ctx, req.Checks)
if err != nil {
return nil, nil, err
}
return checks, req.GetMetadata(), nil
}
func (s *Server) setSessionRequestToCommand(ctx context.Context, req *session.SetSessionRequest) ([]command.SessionCheck, error) {
checks, err := s.checksToCommand(ctx, req.Checks)
if err != nil {
return nil, err
}
return checks, nil
}
func (s *Server) checksToCommand(ctx context.Context, checks *session.Checks) ([]command.SessionCheck, error) {
checkUser, err := userCheck(checks.GetUser())
if err != nil {
return nil, err
}
sessionChecks := make([]command.SessionCheck, 0, 2)
if checkUser != nil {
user, err := checkUser.search(ctx, s.query)
if err != nil {
return nil, err
}
sessionChecks = append(sessionChecks, command.CheckUser(user.ID))
}
if password := checks.GetPassword(); password != nil {
sessionChecks = append(sessionChecks, command.CheckPassword(password.GetPassword()))
}
return sessionChecks, nil
}
func userCheck(user *session.CheckUser) (userSearch, error) {
if user == nil {
return nil, nil
}
switch s := user.GetSearch().(type) {
case *session.CheckUser_UserId:
return userByID(s.UserId), nil
case *session.CheckUser_LoginName:
return userByLoginName(s.LoginName)
default:
return nil, caos_errs.ThrowUnimplementedf(nil, "SESSION-d3b4g0", "user search %T not implemented", s)
}
}
type userSearch interface {
search(ctx context.Context, q *query.Queries) (*query.User, error)
}
func userByID(userID string) userSearch {
return userSearchByID{userID}
}
func userByLoginName(loginName string) (userSearch, error) {
loginNameQuery, err := query.NewUserLoginNamesSearchQuery(loginName)
if err != nil {
return nil, err
}
return userSearchByLoginName{loginNameQuery}, nil
}
type userSearchByID struct {
id string
}
func (u userSearchByID) search(ctx context.Context, q *query.Queries) (*query.User, error) {
return q.GetUserByID(ctx, true, u.id, false)
}
type userSearchByLoginName struct {
loginNameQuery query.SearchQuery
}
func (u userSearchByLoginName) search(ctx context.Context, q *query.Queries) (*query.User, error) {
return q.GetUser(ctx, true, false, u.loginNameQuery)
}

View File

@@ -0,0 +1,379 @@
package session
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
)
func Test_sessionsToPb(t *testing.T) {
now := time.Now()
past := now.Add(-time.Hour)
sessions := []*query.Session{
{ // no factor
ID: "999",
CreationDate: now,
ChangeDate: now,
Sequence: 123,
State: domain.SessionStateActive,
ResourceOwner: "me",
Creator: "he",
Metadata: map[string][]byte{"hello": []byte("world")},
},
{ // user factor
ID: "999",
CreationDate: now,
ChangeDate: now,
Sequence: 123,
State: domain.SessionStateActive,
ResourceOwner: "me",
Creator: "he",
UserFactor: query.SessionUserFactor{
UserID: "345",
UserCheckedAt: past,
LoginName: "donald",
DisplayName: "donald duck",
},
Metadata: map[string][]byte{"hello": []byte("world")},
},
{ // no factor
ID: "999",
CreationDate: now,
ChangeDate: now,
Sequence: 123,
State: domain.SessionStateActive,
ResourceOwner: "me",
Creator: "he",
PasswordFactor: query.SessionPasswordFactor{
PasswordCheckedAt: past,
},
Metadata: map[string][]byte{"hello": []byte("world")},
},
}
want := []*session.Session{
{ // no factor
Id: "999",
CreationDate: timestamppb.New(now),
ChangeDate: timestamppb.New(now),
Sequence: 123,
Factors: nil,
Metadata: map[string][]byte{"hello": []byte("world")},
},
{ // user factor
Id: "999",
CreationDate: timestamppb.New(now),
ChangeDate: timestamppb.New(now),
Sequence: 123,
Factors: &session.Factors{
User: &session.UserFactor{
VerifiedAt: timestamppb.New(past),
Id: "345",
LoginName: "donald",
DisplayName: "donald duck",
},
},
Metadata: map[string][]byte{"hello": []byte("world")},
},
{ // password factor
Id: "999",
CreationDate: timestamppb.New(now),
ChangeDate: timestamppb.New(now),
Sequence: 123,
Factors: &session.Factors{
Password: &session.PasswordFactor{
VerifiedAt: timestamppb.New(past),
},
},
Metadata: map[string][]byte{"hello": []byte("world")},
},
}
out := sessionsToPb(sessions)
require.Len(t, out, len(want))
for i, got := range out {
if !proto.Equal(got, want[i]) {
t.Errorf("session %d got:\n%v\nwant:\n%v", i, got, want)
}
}
}
func mustNewTextQuery(t testing.TB, column query.Column, value string, compare query.TextComparison) query.SearchQuery {
q, err := query.NewTextQuery(column, value, compare)
require.NoError(t, err)
return q
}
func mustNewListQuery(t testing.TB, column query.Column, list []any, compare query.ListComparison) query.SearchQuery {
q, err := query.NewListQuery(query.SessionColumnID, list, compare)
require.NoError(t, err)
return q
}
func Test_listSessionsRequestToQuery(t *testing.T) {
type args struct {
ctx context.Context
req *session.ListSessionsRequest
}
tests := []struct {
name string
args args
want *query.SessionsSearchQueries
wantErr error
}{
{
name: "default request",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
req: &session.ListSessionsRequest{},
},
want: &query.SessionsSearchQueries{
SearchRequest: query.SearchRequest{
Offset: 0,
Limit: 0,
Asc: false,
},
Queries: []query.SearchQuery{
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
},
{
name: "with list query and sessions",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
req: &session.ListSessionsRequest{
Query: &object.ListQuery{
Offset: 10,
Limit: 20,
Asc: true,
},
Queries: []*session.SearchQuery{
{Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
Ids: []string{"1", "2", "3"},
},
}},
{Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
Ids: []string{"4", "5", "6"},
},
}},
},
},
},
want: &query.SessionsSearchQueries{
SearchRequest: query.SearchRequest{
Offset: 10,
Limit: 20,
Asc: true,
},
Queries: []query.SearchQuery{
mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn),
mustNewListQuery(t, query.SessionColumnID, []interface{}{"4", "5", "6"}, query.ListIn),
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
},
{
name: "invalid argument error",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
req: &session.ListSessionsRequest{
Query: &object.ListQuery{
Offset: 10,
Limit: 20,
Asc: true,
},
Queries: []*session.SearchQuery{
{Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
Ids: []string{"1", "2", "3"},
},
}},
{Query: nil},
},
},
},
wantErr: caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := listSessionsRequestToQuery(tt.args.ctx, tt.args.req)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}
func Test_sessionQueriesToQuery(t *testing.T) {
type args struct {
ctx context.Context
queries []*session.SearchQuery
}
tests := []struct {
name string
args args
want []query.SearchQuery
wantErr error
}{
{
name: "creator only",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
},
want: []query.SearchQuery{
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
{
name: "invalid argument",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
queries: []*session.SearchQuery{
{Query: nil},
},
},
wantErr: caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"),
},
{
name: "creator and sessions",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
queries: []*session.SearchQuery{
{Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
Ids: []string{"1", "2", "3"},
},
}},
{Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
Ids: []string{"4", "5", "6"},
},
}},
},
},
want: []query.SearchQuery{
mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn),
mustNewListQuery(t, query.SessionColumnID, []interface{}{"4", "5", "6"}, query.ListIn),
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := sessionQueriesToQuery(tt.args.ctx, tt.args.queries)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}
func Test_sessionQueryToQuery(t *testing.T) {
type args struct {
query *session.SearchQuery
}
tests := []struct {
name string
args args
want query.SearchQuery
wantErr error
}{
{
name: "invalid argument",
args: args{&session.SearchQuery{
Query: nil,
}},
wantErr: caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"),
},
{
name: "query",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
Ids: []string{"1", "2", "3"},
},
},
}},
want: mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := sessionQueryToQuery(tt.args.query)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}
func mustUserLoginNamesSearchQuery(t testing.TB, value string) query.SearchQuery {
loginNameQuery, err := query.NewUserLoginNamesSearchQuery("bar")
require.NoError(t, err)
return loginNameQuery
}
func Test_userCheck(t *testing.T) {
type args struct {
user *session.CheckUser
}
tests := []struct {
name string
args args
want userSearch
wantErr error
}{
{
name: "nil user",
args: args{nil},
want: nil,
},
{
name: "by user id",
args: args{&session.CheckUser{
Search: &session.CheckUser_UserId{
UserId: "foo",
},
}},
want: userSearchByID{"foo"},
},
{
name: "by user id",
args: args{&session.CheckUser{
Search: &session.CheckUser_LoginName{
LoginName: "bar",
},
}},
want: userSearchByLoginName{mustUserLoginNamesSearchQuery(t, "bar")},
},
{
name: "unimplemented error",
args: args{&session.CheckUser{
Search: nil,
}},
wantErr: caos_errs.ThrowUnimplementedf(nil, "SESSION-d3b4g0", "user search %T not implemented", nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := userCheck(tt.args.user)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -21,11 +21,7 @@ func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp
case *user.SetEmailRequest_ReturnCode:
email, err = s.command.ChangeUserEmailReturnCode(ctx, req.GetUserId(), resourceOwner, req.GetEmail(), s.userCodeAlg)
case *user.SetEmailRequest_IsVerified:
if v.IsVerified {
email, err = s.command.ChangeUserEmailVerified(ctx, req.GetUserId(), resourceOwner, req.GetEmail())
} else {
email, err = s.command.ChangeUserEmail(ctx, req.GetUserId(), resourceOwner, req.GetEmail(), s.userCodeAlg)
}
email, err = s.command.ChangeUserEmailVerified(ctx, req.GetUserId(), resourceOwner, req.GetEmail())
case nil:
email, err = s.command.ChangeUserEmail(ctx, req.GetUserId(), resourceOwner, req.GetEmail(), s.userCodeAlg)
default:

View File

@@ -0,0 +1,210 @@
//go:build integration
package user_test
import (
"fmt"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/integration"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
"google.golang.org/protobuf/types/known/timestamppb"
)
func createHumanUser(t *testing.T) *user.AddHumanUserResponse {
resp, err := Client.AddHumanUser(CTX, &user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Mickey",
LastName: "Mouse",
},
Email: &user.SetHumanEmail{
Email: fmt.Sprintf("%d@mouse.com", time.Now().UnixNano()),
Verification: &user.SetHumanEmail_ReturnCode{
ReturnCode: &user.ReturnEmailVerificationCode{},
},
},
})
require.NoError(t, err)
require.NotEmpty(t, resp.GetUserId())
return resp
}
func TestServer_SetEmail(t *testing.T) {
userID := createHumanUser(t).GetUserId()
tests := []struct {
name string
req *user.SetEmailRequest
want *user.SetEmailResponse
wantErr bool
}{
{
name: "default verfication",
req: &user.SetEmailRequest{
UserId: userID,
Email: "default-verifier@mouse.com",
},
want: &user.SetEmailResponse{
Details: &object.Details{
Sequence: 1,
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "custom url template",
req: &user.SetEmailRequest{
UserId: userID,
Email: "custom-url@mouse.com",
Verification: &user.SetEmailRequest_SendCode{
SendCode: &user.SendEmailVerificationCode{
UrlTemplate: gu.Ptr("https://example.com/email/verify?userID={{.UserID}}&code={{.Code}}&orgID={{.OrgID}}"),
},
},
},
want: &user.SetEmailResponse{
Details: &object.Details{
Sequence: 1,
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "template error",
req: &user.SetEmailRequest{
UserId: userID,
Email: "custom-url@mouse.com",
Verification: &user.SetEmailRequest_SendCode{
SendCode: &user.SendEmailVerificationCode{
UrlTemplate: gu.Ptr("{{"),
},
},
},
wantErr: true,
},
{
name: "return code",
req: &user.SetEmailRequest{
UserId: userID,
Email: "return-code@mouse.com",
Verification: &user.SetEmailRequest_ReturnCode{
ReturnCode: &user.ReturnEmailVerificationCode{},
},
},
want: &user.SetEmailResponse{
Details: &object.Details{
Sequence: 1,
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
VerificationCode: gu.Ptr("xxx"),
},
},
{
name: "is verified true",
req: &user.SetEmailRequest{
UserId: userID,
Email: "verified-true@mouse.com",
Verification: &user.SetEmailRequest_IsVerified{
IsVerified: true,
},
},
want: &user.SetEmailResponse{
Details: &object.Details{
Sequence: 1,
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "is verified false",
req: &user.SetEmailRequest{
UserId: userID,
Email: "verified-false@mouse.com",
Verification: &user.SetEmailRequest_IsVerified{
IsVerified: false,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.SetEmail(CTX, tt.req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
integration.AssertDetails(t, tt.want, got)
if tt.want.GetVerificationCode() != "" {
assert.NotEmpty(t, got.GetVerificationCode())
}
})
}
}
func TestServer_VerifyEmail(t *testing.T) {
userResp := createHumanUser(t)
tests := []struct {
name string
req *user.VerifyEmailRequest
want *user.VerifyEmailResponse
wantErr bool
}{
{
name: "wrong code",
req: &user.VerifyEmailRequest{
UserId: userResp.GetUserId(),
VerificationCode: "xxx",
},
wantErr: true,
},
{
name: "wrong user",
req: &user.VerifyEmailRequest{
UserId: "xxx",
VerificationCode: userResp.GetEmailCode(),
},
wantErr: true,
},
{
name: "verify user",
req: &user.VerifyEmailRequest{
UserId: userResp.GetUserId(),
VerificationCode: userResp.GetEmailCode(),
},
want: &user.VerifyEmailResponse{
Details: &object.Details{
Sequence: 1,
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.VerifyEmail(CTX, tt.req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
integration.AssertDetails(t, tt.want, got)
})
}
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
func (s *Server) AddHumanUser(ctx context.Context, req *user.AddHumanUserRequest) (_ *user.AddHumanUserResponse, err error) {
@@ -19,10 +19,7 @@ func (s *Server) AddHumanUser(ctx context.Context, req *user.AddHumanUserRequest
if err != nil {
return nil, err
}
orgID := req.GetOrganisation().GetOrgId()
if orgID == "" {
orgID = authz.GetCtxData(ctx).OrgID
}
orgID := authz.GetCtxData(ctx).OrgID
err = s.command.AddHuman(ctx, orgID, human, false)
if err != nil {
return nil, err

View File

@@ -0,0 +1,317 @@
//go:build integration
package user_test
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/integration"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
"google.golang.org/protobuf/types/known/timestamppb"
)
var (
CTX context.Context
ErrCTX context.Context
Tester *integration.Tester
Client user.UserServiceClient
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, errCtx, cancel := integration.Contexts(time.Hour)
defer cancel()
Tester = integration.NewTester(ctx)
defer Tester.Done()
CTX, ErrCTX = Tester.WithSystemAuthorization(ctx, integration.OrgOwner), errCtx
Client = user.NewUserServiceClient(Tester.GRPCClientConn)
return m.Run()
}())
}
func TestServer_AddHumanUser(t *testing.T) {
type args struct {
ctx context.Context
req *user.AddHumanUserRequest
}
tests := []struct {
name string
args args
want *user.AddHumanUserResponse
wantErr bool
}{
{
name: "default verification",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Donald",
LastName: "Duck",
NickName: gu.Ptr("Dukkie"),
DisplayName: gu.Ptr("Donald Duck"),
PreferredLanguage: gu.Ptr("en"),
Gender: user.Gender_GENDER_DIVERSE.Enum(),
},
Email: &user.SetHumanEmail{},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: true,
},
},
},
},
want: &user.AddHumanUserResponse{
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "return verification code",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Donald",
LastName: "Duck",
NickName: gu.Ptr("Dukkie"),
DisplayName: gu.Ptr("Donald Duck"),
PreferredLanguage: gu.Ptr("en"),
Gender: user.Gender_GENDER_DIVERSE.Enum(),
},
Email: &user.SetHumanEmail{
Verification: &user.SetHumanEmail_ReturnCode{
ReturnCode: &user.ReturnEmailVerificationCode{},
},
},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: true,
},
},
},
},
want: &user.AddHumanUserResponse{
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
EmailCode: gu.Ptr("something"),
},
},
{
name: "custom template",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Donald",
LastName: "Duck",
NickName: gu.Ptr("Dukkie"),
DisplayName: gu.Ptr("Donald Duck"),
PreferredLanguage: gu.Ptr("en"),
Gender: user.Gender_GENDER_DIVERSE.Enum(),
},
Email: &user.SetHumanEmail{
Verification: &user.SetHumanEmail_SendCode{
SendCode: &user.SendEmailVerificationCode{
UrlTemplate: gu.Ptr("https://example.com/email/verify?userID={{.UserID}}&code={{.Code}}&orgID={{.OrgID}}"),
},
},
},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: true,
},
},
},
},
want: &user.AddHumanUserResponse{
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "custom template error",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Donald",
LastName: "Duck",
NickName: gu.Ptr("Dukkie"),
DisplayName: gu.Ptr("Donald Duck"),
PreferredLanguage: gu.Ptr("en"),
Gender: user.Gender_GENDER_DIVERSE.Enum(),
},
Email: &user.SetHumanEmail{
Verification: &user.SetHumanEmail_SendCode{
SendCode: &user.SendEmailVerificationCode{
UrlTemplate: gu.Ptr("{{"),
},
},
},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: true,
},
},
},
},
wantErr: true,
},
{
name: "missing REQUIRED profile",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Email: &user.SetHumanEmail{
Verification: &user.SetHumanEmail_ReturnCode{
ReturnCode: &user.ReturnEmailVerificationCode{},
},
},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: true,
},
},
},
},
wantErr: true,
},
{
name: "missing REQUIRED email",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Donald",
LastName: "Duck",
NickName: gu.Ptr("Dukkie"),
DisplayName: gu.Ptr("Donald Duck"),
PreferredLanguage: gu.Ptr("en"),
Gender: user.Gender_GENDER_DIVERSE.Enum(),
},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: true,
},
},
},
},
wantErr: true,
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userID := fmt.Sprint(time.Now().UnixNano() + int64(i))
tt.args.req.UserId = &userID
if email := tt.args.req.GetEmail(); email != nil {
email.Email = fmt.Sprintf("%s@me.now", userID)
}
if tt.want != nil {
tt.want.UserId = userID
}
got, err := Client.AddHumanUser(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want.GetUserId(), got.GetUserId())
if tt.want.GetEmailCode() != "" {
assert.NotEmpty(t, got.GetEmailCode())
}
integration.AssertDetails(t, tt.want, got)
})
}
}

View File

@@ -23,6 +23,7 @@ const (
XUserAgent = "x-user-agent"
XGrpcWeb = "x-grpc-web"
XRequestedWith = "x-requested-with"
XRobotsTag = "x-robots-tag"
IfNoneMatch = "If-None-Match"
LastModified = "Last-Modified"
Etag = "Etag"

View File

@@ -43,12 +43,10 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
return next
}
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
var err error
tracingCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
@@ -63,8 +61,13 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
wrappedWriter.ignoreWrites = true
}
checkSpan.End()
next.ServeHTTP(wrappedWriter, request)
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
defer writeSpan.End()
requestURL := request.RequestURI
unescapedURL, err := url.QueryUnescape(requestURL)
if err != nil {

View File

@@ -62,7 +62,7 @@ func authorize(r *http.Request, verifier *authz.TokenVerifier, authConfig authz.
return nil, errors.New("auth header missing")
}
ctxSetter, err := authz.CheckUserAuthorization(authCtx, &httpReq{}, authToken, http_util.GetOrgID(r), verifier, authConfig, authOpt, r.RequestURI)
ctxSetter, err := authz.CheckUserAuthorization(authCtx, &httpReq{}, authToken, http_util.GetOrgID(r), "", verifier, authConfig, authOpt, r.RequestURI)
if err != nil {
return nil, err
}

View File

@@ -0,0 +1,14 @@
package middleware
import (
"net/http"
http_utils "github.com/zitadel/zitadel/internal/api/http"
)
func RobotsTagHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(http_utils.XRobotsTag, "none")
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,24 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_RobotsTagInterceptor(t *testing.T) {
testHandler := func(w http.ResponseWriter, r *http.Request) {}
req := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder()
handler := RobotsTagHandler(http.HandlerFunc(testHandler))
handler.ServeHTTP(recorder, req)
res := recorder.Result()
exp := res.Header.Get("X-Robots-Tag")
assert.Equal(t, "none", exp)
defer res.Body.Close()
}

View File

@@ -0,0 +1,19 @@
package robots_txt
import (
"fmt"
"net/http"
)
const (
HandlerPrefix = "/robots.txt"
)
func Start() (http.Handler, error) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/text")
fmt.Fprintf(w, "User-agent: *\nDisallow: /\n")
})
return handler, nil
}

View File

@@ -0,0 +1,28 @@
package robots_txt
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_RobotsTxt(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
recorder := httptest.NewRecorder()
handler, err := Start()
handler.ServeHTTP(recorder, req)
assert.Equal(t, nil, err)
res := recorder.Result()
body, err := io.ReadAll(res.Body)
assert.Equal(t, nil, err)
assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "User-agent: *\nDisallow: /\n", string(body))
defer res.Body.Close()
}

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"github.com/dop251/goja"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/text/language"
@@ -14,6 +15,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/query"
)
func (l *Login) runPostExternalAuthenticationActions(
@@ -26,7 +28,21 @@ func (l *Login) runPostExternalAuthenticationActions(
) (_ *domain.ExternalUser, userChanged bool, err error) {
ctx := httpRequest.Context()
// use the request org (scopes or domain discovery) as default
resourceOwner := authRequest.RequestedOrgID
// if the user was already linked to an IDP and redirected to that, the requested org might be empty
if resourceOwner == "" {
resourceOwner = authRequest.UserOrgID
}
// if we will have no org (e.g. user clicked directly on the IDP on the login page)
if resourceOwner == "" {
// in this case the user might nevertheless already be linked to an IDP,
// so let's do a workaround and resourceOwnerOfUserIDPLink if there would be a IDP link
resourceOwner, err = l.resourceOwnerOfUserIDPLink(ctx, authRequest.SelectedIDPConfigID, user.ExternalUserID)
logging.WithFields("authReq", authRequest.ID, "idpID", authRequest.SelectedIDPConfigID).OnError(err).
Warn("could not determine resource owner for runPostExternalAuthenticationActions, fall back to default org id")
}
// fallback to default org id
if resourceOwner == "" {
resourceOwner = authz.GetInstance(ctx).DefaultOrganisationID()
}
@@ -394,3 +410,25 @@ func tokenCtxFields(tokens *oidc.Tokens[*oidc.IDTokenClaims]) []actions.FieldOpt
actions.SetFields("claimsJSON", claimsJSON),
}
}
func (l *Login) resourceOwnerOfUserIDPLink(ctx context.Context, idpConfigID string, externalUserID string) (string, error) {
idQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpConfigID)
if err != nil {
return "", err
}
externalIDQuery, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
if err != nil {
return "", err
}
queries := []query.SearchQuery{
idQuery, externalIDQuery,
}
links, err := l.query.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: queries}, false)
if err != nil {
return "", err
}
if len(links.Links) != 1 {
return "", nil
}
return links.Links[0].ResourceOwner, nil
}

View File

@@ -23,6 +23,7 @@
<title>{{ .Title }}</title>
<meta name="description" content="{{ .Description }}"/>
<meta name="robots" content="none" />
<script src="{{ resourceUrl "scripts/theme.js" }}"></script>
</head>

View File

@@ -20,6 +20,7 @@ import (
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/repository/session"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
usr_grant_repo "github.com/zitadel/zitadel/internal/repository/usergrant"
"github.com/zitadel/zitadel/internal/static"
@@ -29,7 +30,7 @@ import (
type Commands struct {
httpClient *http.Client
checkPermission permissionCheck
checkPermission domain.PermissionCheck
newEmailCode func(ctx context.Context, filter preparation.FilterToQueryReducer, codeAlg crypto.EncryptionAlgorithm) (*CryptoCodeWithExpiry, error)
eventstore *eventstore.Eventstore
@@ -50,6 +51,8 @@ type Commands struct {
domainVerificationAlg crypto.EncryptionAlgorithm
domainVerificationGenerator crypto.Generator
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
sessionTokenCreator func(sessionID string) (id string, token string, err error)
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
multifactors domain.MultifactorConfigs
webauthnConfig *webauthn_helper.Config
@@ -71,24 +74,21 @@ func StartCommands(
externalDomain string,
externalSecure bool,
externalPort uint16,
idpConfigEncryption,
otpEncryption,
smtpEncryption,
smsEncryption,
userEncryption,
domainVerificationEncryption,
oidcEncryption,
samlEncryption crypto.EncryptionAlgorithm,
idpConfigEncryption, otpEncryption, smtpEncryption, smsEncryption, userEncryption, domainVerificationEncryption, oidcEncryption, samlEncryption crypto.EncryptionAlgorithm,
httpClient *http.Client,
membershipsResolver authz.MembershipsResolver,
permissionCheck domain.PermissionCheck,
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
) (repo *Commands, err error) {
if externalDomain == "" {
return nil, errors.ThrowInvalidArgument(nil, "COMMAND-Df21s", "no external domain specified")
}
idGenerator := id.SonyFlakeGenerator()
// reuse the oidcEncryption to be able to handle both tokens in the interceptor later on
sessionAlg := oidcEncryption
repo = &Commands{
eventstore: es,
static: staticStore,
idGenerator: id.SonyFlakeGenerator(),
idGenerator: idGenerator,
zitadelRoles: zitadelRoles,
externalDomain: externalDomain,
externalSecure: externalSecure,
@@ -107,10 +107,10 @@ func StartCommands(
certificateAlgorithm: samlEncryption,
webauthnConfig: webAuthN,
httpClient: httpClient,
checkPermission: func(ctx context.Context, permission, orgID, resourceID string, allowSelf bool) (err error) {
return authz.CheckPermission(ctx, membershipsResolver, zitadelRoles, permission, orgID, resourceID, allowSelf)
},
newEmailCode: newEmailCode,
checkPermission: permissionCheck,
newEmailCode: newEmailCode,
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
sessionTokenVerifier: sessionTokenVerifier,
}
instance_repo.RegisterEventMappers(repo.eventstore)
@@ -121,6 +121,7 @@ func StartCommands(
keypair.RegisterEventMappers(repo.eventstore)
action.RegisterEventMappers(repo.eventstore)
quota.RegisterEventMappers(repo.eventstore)
session.RegisterEventMappers(repo.eventstore)
repo.userPasswordAlg = crypto.NewBCrypt(defaults.SecretGenerators.PasswordSaltCost)
repo.machineKeySize = int(defaults.SecretGenerators.MachineKeySize)

View File

@@ -10,6 +10,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
@@ -19,6 +20,7 @@ import (
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/session"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
)
@@ -38,6 +40,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
usergrant.RegisterEventMappers(es)
key_repo.RegisterEventMappers(es)
action_repo.RegisterEventMappers(es)
session.RegisterEventMappers(es)
return es
}
@@ -125,6 +128,11 @@ func expectFilter(events ...*repository.Event) expect {
m.ExpectFilterEvents(events...)
}
}
func expectFilterError(err error) expect {
return func(m *mock.MockRepository) {
m.ExpectFilterEventsError(err)
}
}
func expectFilterOrgDomainNotFound() expect {
return func(m *mock.MockRepository) {
@@ -250,14 +258,14 @@ func (m *mockInstance) SecurityPolicyAllowedOrigins() []string {
return nil
}
func newMockPermissionCheckAllowed() permissionCheck {
return func(ctx context.Context, permission, orgID, resourceID string, allowSelf bool) (err error) {
func newMockPermissionCheckAllowed() domain.PermissionCheck {
return func(ctx context.Context, permission, orgID, resourceID string) (err error) {
return nil
}
}
func newMockPermissionCheckNotAllowed() permissionCheck {
return func(ctx context.Context, permission, orgID, resourceID string, allowSelf bool) (err error) {
func newMockPermissionCheckNotAllowed() domain.PermissionCheck {
return func(ctx context.Context, permission, orgID, resourceID string) (err error) {
return errors.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied")
}
}

View File

@@ -1,11 +0,0 @@
package command
import (
"context"
)
type permissionCheck func(ctx context.Context, permission, orgID, resourceID string, allowSelf bool) (err error)
const (
permissionUserWrite = "user.write"
)

225
internal/command/session.go Normal file
View File

@@ -0,0 +1,225 @@
package command
import (
"context"
"encoding/base64"
"fmt"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type SessionCheck func(ctx context.Context, cmd *SessionChecks) error
type SessionChecks struct {
checks []SessionCheck
sessionWriteModel *SessionWriteModel
passwordWriteModel *HumanPasswordWriteModel
eventstore *eventstore.Eventstore
userPasswordAlg crypto.HashAlgorithm
createToken func(sessionID string) (id string, token string, err error)
now func() time.Time
}
func (c *Commands) NewSessionChecks(checks []SessionCheck, session *SessionWriteModel) *SessionChecks {
return &SessionChecks{
checks: checks,
sessionWriteModel: session,
eventstore: c.eventstore,
userPasswordAlg: c.userPasswordAlg,
createToken: c.sessionTokenCreator,
now: time.Now,
}
}
// CheckUser defines a user check to be executed for a session update
func CheckUser(id string) SessionCheck {
return func(ctx context.Context, cmd *SessionChecks) error {
if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id {
return caos_errs.ThrowInvalidArgument(nil, "", "user change not possible")
}
return cmd.sessionWriteModel.UserChecked(ctx, id, cmd.now())
}
}
// CheckPassword defines a password check to be executed for a session update
func CheckPassword(password string) SessionCheck {
return func(ctx context.Context, cmd *SessionChecks) error {
if cmd.sessionWriteModel.UserID == "" {
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sfw3f", "Errors.User.UserIDMissing")
}
cmd.passwordWriteModel = NewHumanPasswordWriteModel(cmd.sessionWriteModel.UserID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.passwordWriteModel)
if err != nil {
return err
}
if cmd.passwordWriteModel.UserState == domain.UserStateUnspecified || cmd.passwordWriteModel.UserState == domain.UserStateDeleted {
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Df4b3", "Errors.User.NotFound")
}
if cmd.passwordWriteModel.Secret == nil {
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-WEf3t", "Errors.User.Password.NotSet")
}
ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "crypto.CompareHash")
err = crypto.CompareHash(cmd.passwordWriteModel.Secret, []byte(password), cmd.userPasswordAlg)
spanPasswordComparison.EndWithError(err)
if err != nil {
//TODO: maybe we want to reset the session in the future https://github.com/zitadel/zitadel/issues/5807
return caos_errs.ThrowInvalidArgument(err, "COMMAND-SAF3g", "Errors.User.Password.Invalid")
}
cmd.sessionWriteModel.PasswordChecked(ctx, cmd.now())
return nil
}
}
// Check will execute the checks specified and return an error on the first occurrence
func (s *SessionChecks) Check(ctx context.Context) error {
for _, check := range s.checks {
if err := check(ctx, s); err != nil {
return err
}
}
return nil
}
func (s *SessionChecks) commands(ctx context.Context) (string, []eventstore.Command, error) {
if len(s.sessionWriteModel.commands) == 0 {
return "", nil, nil
}
tokenID, token, err := s.createToken(s.sessionWriteModel.AggregateID)
if err != nil {
return "", nil, err
}
s.sessionWriteModel.SetToken(ctx, tokenID)
return token, s.sessionWriteModel.commands, nil
}
func (c *Commands) CreateSession(ctx context.Context, checks []SessionCheck, metadata map[string][]byte) (set *SessionChanged, err error) {
sessionID, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, err
}
cmd := c.NewSessionChecks(checks, sessionWriteModel)
cmd.sessionWriteModel.Start(ctx)
return c.updateSession(ctx, cmd, metadata)
}
func (c *Commands) UpdateSession(ctx context.Context, sessionID, sessionToken string, checks []SessionCheck, metadata map[string][]byte) (set *SessionChanged, err error) {
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, err
}
if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionWrite); err != nil {
return nil, err
}
cmd := c.NewSessionChecks(checks, sessionWriteModel)
return c.updateSession(ctx, cmd, metadata)
}
func (c *Commands) TerminateSession(ctx context.Context, sessionID, sessionToken string) (*domain.ObjectDetails, error) {
sessionWriteModel := NewSessionWriteModel(sessionID, "")
if err := c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil {
return nil, err
}
if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionDelete); err != nil {
return nil, err
}
if sessionWriteModel.State != domain.SessionStateActive {
return writeModelToObjectDetails(&sessionWriteModel.WriteModel), nil
}
terminate := session.NewTerminateEvent(ctx, &session.NewAggregate(sessionWriteModel.AggregateID, sessionWriteModel.ResourceOwner).Aggregate)
pushedEvents, err := c.eventstore.Push(ctx, terminate)
if err != nil {
return nil, err
}
err = AppendAndReduce(sessionWriteModel, pushedEvents...)
if err != nil {
return nil, err
}
return writeModelToObjectDetails(&sessionWriteModel.WriteModel), nil
}
// updateSession execute the [SessionChecks] where new events will be created and as well as for metadata (changes)
func (c *Commands) updateSession(ctx context.Context, checks *SessionChecks, metadata map[string][]byte) (set *SessionChanged, err error) {
if checks.sessionWriteModel.State == domain.SessionStateTerminated {
return nil, caos_errs.ThrowPreconditionFailed(nil, "COMAND-SAjeh", "Errors.Session.Terminated")
}
if err := checks.Check(ctx); err != nil {
// TODO: how to handle failed checks (e.g. pw wrong) https://github.com/zitadel/zitadel/issues/5807
return nil, err
}
checks.sessionWriteModel.ChangeMetadata(ctx, metadata)
sessionToken, cmds, err := checks.commands(ctx)
if err != nil {
return nil, err
}
if len(cmds) == 0 {
return sessionWriteModelToSessionChanged(checks.sessionWriteModel), nil
}
pushedEvents, err := c.eventstore.Push(ctx, cmds...)
if err != nil {
return nil, err
}
err = AppendAndReduce(checks.sessionWriteModel, pushedEvents...)
if err != nil {
return nil, err
}
changed := sessionWriteModelToSessionChanged(checks.sessionWriteModel)
changed.NewToken = sessionToken
return changed, nil
}
// sessionPermission will check that the provided sessionToken is correct or
// if empty, check that the caller is granted the necessary permission
func (c *Commands) sessionPermission(ctx context.Context, sessionWriteModel *SessionWriteModel, sessionToken, permission string) (err error) {
if sessionToken == "" {
return c.checkPermission(ctx, permission, authz.GetCtxData(ctx).OrgID, sessionWriteModel.AggregateID)
}
return c.sessionTokenVerifier(ctx, sessionToken, sessionWriteModel.AggregateID, sessionWriteModel.TokenID)
}
func sessionTokenCreator(idGenerator id.Generator, sessionAlg crypto.EncryptionAlgorithm) func(sessionID string) (id string, token string, err error) {
return func(sessionID string) (id string, token string, err error) {
id, err = idGenerator.Next()
if err != nil {
return "", "", err
}
encrypted, err := sessionAlg.Encrypt([]byte(fmt.Sprintf(authz.SessionTokenFormat, sessionID, id)))
if err != nil {
return "", "", err
}
return id, base64.RawURLEncoding.EncodeToString(encrypted), nil
}
}
type SessionChanged struct {
*domain.ObjectDetails
ID string
NewToken string
}
func sessionWriteModelToSessionChanged(wm *SessionWriteModel) *SessionChanged {
return &SessionChanged{
ObjectDetails: &domain.ObjectDetails{
Sequence: wm.ProcessedSequence,
EventDate: wm.ChangeDate,
ResourceOwner: wm.ResourceOwner,
},
ID: wm.AggregateID,
}
}

View File

@@ -0,0 +1,139 @@
package command
import (
"bytes"
"context"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/session"
)
type SessionWriteModel struct {
eventstore.WriteModel
TokenID string
UserID string
UserCheckedAt time.Time
PasswordCheckedAt time.Time
Metadata map[string][]byte
State domain.SessionState
commands []eventstore.Command
aggregate *eventstore.Aggregate
}
func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel {
return &SessionWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: sessionID,
ResourceOwner: resourceOwner,
},
Metadata: make(map[string][]byte),
aggregate: &session.NewAggregate(sessionID, resourceOwner).Aggregate,
}
}
func (wm *SessionWriteModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *session.AddedEvent:
wm.reduceAdded(e)
case *session.UserCheckedEvent:
wm.reduceUserChecked(e)
case *session.PasswordCheckedEvent:
wm.reducePasswordChecked(e)
case *session.TokenSetEvent:
wm.reduceTokenSet(e)
case *session.TerminateEvent:
wm.reduceTerminate()
}
}
return wm.WriteModel.Reduce()
}
func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(session.AggregateType).
AggregateIDs(wm.AggregateID).
EventTypes(
session.AddedType,
session.UserCheckedType,
session.PasswordCheckedType,
session.TokenSetType,
session.MetadataSetType,
session.TerminateType,
).
Builder()
if wm.ResourceOwner != "" {
query.ResourceOwner(wm.ResourceOwner)
}
return query
}
func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) {
wm.State = domain.SessionStateActive
}
func (wm *SessionWriteModel) reduceUserChecked(e *session.UserCheckedEvent) {
wm.UserID = e.UserID
wm.UserCheckedAt = e.CheckedAt
}
func (wm *SessionWriteModel) reducePasswordChecked(e *session.PasswordCheckedEvent) {
wm.PasswordCheckedAt = e.CheckedAt
}
func (wm *SessionWriteModel) reduceTokenSet(e *session.TokenSetEvent) {
wm.TokenID = e.TokenID
}
func (wm *SessionWriteModel) reduceTerminate() {
wm.State = domain.SessionStateTerminated
}
func (wm *SessionWriteModel) Start(ctx context.Context) {
wm.commands = append(wm.commands, session.NewAddedEvent(ctx, wm.aggregate))
}
func (wm *SessionWriteModel) UserChecked(ctx context.Context, userID string, checkedAt time.Time) error {
wm.commands = append(wm.commands, session.NewUserCheckedEvent(ctx, wm.aggregate, userID, checkedAt))
// set the userID so other checks can use it
wm.UserID = userID
return nil
}
func (wm *SessionWriteModel) PasswordChecked(ctx context.Context, checkedAt time.Time) {
wm.commands = append(wm.commands, session.NewPasswordCheckedEvent(ctx, wm.aggregate, checkedAt))
}
func (wm *SessionWriteModel) SetToken(ctx context.Context, tokenID string) {
wm.commands = append(wm.commands, session.NewTokenSetEvent(ctx, wm.aggregate, tokenID))
}
func (wm *SessionWriteModel) ChangeMetadata(ctx context.Context, metadata map[string][]byte) {
var changed bool
for key, value := range metadata {
currentValue, exists := wm.Metadata[key]
if len(value) != 0 {
// if a value is provided, and it's not equal, change it
if !bytes.Equal(currentValue, value) {
wm.Metadata[key] = value
changed = true
}
} else {
// if there's no / an empty value, we only need to remove it on existing entries
if exists {
delete(wm.Metadata, key)
changed = true
}
}
}
if changed {
wm.commands = append(wm.commands, session.NewMetadataSetEvent(ctx, wm.aggregate, wm.Metadata))
}
}

View File

@@ -0,0 +1,547 @@
package command
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
)
func TestCommands_CreateSession(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
tokenCreator func(sessionID string) (string, string, error)
}
type args struct {
ctx context.Context
checks []SessionCheck
metadata map[string][]byte
}
type res struct {
want *SessionChanged
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"id generator fails",
fields{
idGenerator: mock.NewIDGeneratorExpectError(t, caos_errs.ThrowInternal(nil, "id", "generator failed")),
},
args{
ctx: context.Background(),
},
res{
err: caos_errs.ThrowInternal(nil, "id", "generator failed"),
},
},
{
"eventstore failed",
fields{
idGenerator: mock.NewIDGeneratorExpectIDs(t, "sessionID"),
eventstore: eventstoreExpect(t,
expectFilterError(caos_errs.ThrowInternal(nil, "id", "filter failed")),
),
},
args{
ctx: context.Background(),
},
res{
err: caos_errs.ThrowInternal(nil, "id", "filter failed"),
},
},
{
"empty session",
fields{
idGenerator: mock.NewIDGeneratorExpectIDs(t, "sessionID"),
eventstore: eventstoreExpect(t,
expectFilter(),
expectPush(
eventPusherToEvents(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID",
),
),
),
),
tokenCreator: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
},
args{
ctx: authz.NewMockContext("", "org1", ""),
},
res{
want: &SessionChanged{
ObjectDetails: &domain.ObjectDetails{ResourceOwner: "org1"},
ID: "sessionID",
NewToken: "token",
},
},
},
// the rest is tested in the Test_updateSession
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
sessionTokenCreator: tt.fields.tokenCreator,
}
got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})
}
}
func TestCommands_UpdateSession(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
}
type args struct {
ctx context.Context
sessionID string
sessionToken string
checks []SessionCheck
metadata map[string][]byte
}
type res struct {
want *SessionChanged
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"eventstore failed",
fields{
eventstore: eventstoreExpect(t,
expectFilterError(caos_errs.ThrowInternal(nil, "id", "filter failed")),
),
},
args{
ctx: context.Background(),
},
res{
err: caos_errs.ThrowInternal(nil, "id", "filter failed"),
},
},
{
"invalid session token",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid")
},
},
args{
ctx: context.Background(),
sessionID: "sessionID",
sessionToken: "invalid",
},
res{
err: caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
},
},
{
"no change",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
},
args{
ctx: context.Background(),
sessionID: "sessionID",
sessionToken: "token",
},
res{
want: &SessionChanged{
ObjectDetails: &domain.ObjectDetails{
ResourceOwner: "org1",
},
ID: "sessionID",
NewToken: "",
},
},
},
// the rest is tested in the Test_updateSession
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
sessionTokenVerifier: tt.fields.tokenVerifier,
}
got, err := c.UpdateSession(tt.args.ctx, tt.args.sessionID, tt.args.sessionToken, tt.args.checks, tt.args.metadata)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})
}
}
func TestCommands_updateSession(t *testing.T) {
testNow := time.Now()
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
checks *SessionChecks
metadata map[string][]byte
}
type res struct {
want *SessionChanged
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"terminated",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: context.Background(),
checks: &SessionChecks{
sessionWriteModel: &SessionWriteModel{State: domain.SessionStateTerminated},
},
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "COMAND-SAjeh", "Errors.Session.Terminated"),
},
},
{
"check failed",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: context.Background(),
checks: &SessionChecks{
sessionWriteModel: NewSessionWriteModel("sessionID", "org1"),
checks: []SessionCheck{
func(ctx context.Context, cmd *SessionChecks) error {
return caos_errs.ThrowInternal(nil, "id", "check failed")
},
},
},
},
res{
err: caos_errs.ThrowInternal(nil, "id", "check failed"),
},
},
{
"no change",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: context.Background(),
checks: &SessionChecks{
sessionWriteModel: NewSessionWriteModel("sessionID", "org1"),
checks: []SessionCheck{},
},
},
res{
want: &SessionChanged{
ObjectDetails: &domain.ObjectDetails{
ResourceOwner: "org1",
},
ID: "sessionID",
NewToken: "",
},
},
},
{
"set user, password, metadata and token",
fields{
eventstore: eventstoreExpect(t,
expectPush(
eventPusherToEvents(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"userID", testNow),
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
testNow),
session.NewMetadataSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
map[string][]byte{"key": []byte("value")}),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID"),
),
),
),
},
args{
ctx: context.Background(),
checks: &SessionChecks{
sessionWriteModel: NewSessionWriteModel("sessionID", "org1"),
checks: []SessionCheck{
CheckUser("userID"),
CheckPassword("password"),
},
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
"username", "", "", "", "", language.English, domain.GenderUnspecified, "", false),
),
eventFromEventPusher(
user.NewHumanPasswordChangedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
&crypto.CryptoValue{
CryptoType: crypto.TypeHash,
Algorithm: "hash",
KeyID: "",
Crypted: []byte("password"),
}, false, ""),
),
),
),
createToken: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
userPasswordAlg: crypto.CreateMockHashAlg(gomock.NewController(t)),
now: func() time.Time {
return testNow
},
},
metadata: map[string][]byte{
"key": []byte("value"),
},
},
res{
want: &SessionChanged{
ObjectDetails: &domain.ObjectDetails{
ResourceOwner: "org1",
},
ID: "sessionID",
NewToken: "token",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
got, err := c.updateSession(tt.args.ctx, tt.args.checks, tt.args.metadata)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})
}
}
func TestCommands_TerminateSession(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
}
type args struct {
ctx context.Context
sessionID string
sessionToken string
}
type res struct {
want *domain.ObjectDetails
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"eventstore failed",
fields{
eventstore: eventstoreExpect(t,
expectFilterError(caos_errs.ThrowInternal(nil, "id", "filter failed")),
),
},
args{
ctx: context.Background(),
},
res{
err: caos_errs.ThrowInternal(nil, "id", "filter failed"),
},
},
{
"invalid session token",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid")
},
},
args{
ctx: context.Background(),
sessionID: "sessionID",
sessionToken: "invalid",
},
res{
err: caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
},
},
{
"not active",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
eventFromEventPusher(
session.NewTerminateEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
},
args{
ctx: context.Background(),
sessionID: "sessionID",
sessionToken: "token",
},
res{
want: &domain.ObjectDetails{
ResourceOwner: "org1",
},
},
},
{
"push failed",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID"),
),
),
expectPushFailed(
caos_errs.ThrowInternal(nil, "id", "pushed failed"),
eventPusherToEvents(
session.NewTerminateEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
},
args{
ctx: context.Background(),
sessionID: "sessionID",
sessionToken: "token",
},
res{
err: caos_errs.ThrowInternal(nil, "id", "pushed failed"),
},
},
{
"terminate",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID"),
),
),
expectPush(
eventPusherToEvents(
session.NewTerminateEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
},
args{
ctx: context.Background(),
sessionID: "sessionID",
sessionToken: "token",
},
res{
want: &domain.ObjectDetails{
ResourceOwner: "org1",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
sessionTokenVerifier: tt.fields.tokenVerifier,
}
got, err := c.TerminateSession(tt.args.ctx, tt.args.sessionID, tt.args.sessionToken)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
@@ -42,7 +43,7 @@ func (c *Commands) ChangeUserEmailVerified(ctx context.Context, userID, resource
if err != nil {
return nil, err
}
if err = c.checkPermission(ctx, permissionUserWrite, cmd.aggregate.ResourceOwner, userID, false); err != nil {
if err = c.checkPermission(ctx, domain.PermissionUserWrite, cmd.aggregate.ResourceOwner, userID); err != nil {
return nil, err
}
if err = cmd.Change(ctx, domain.EmailAddress(email)); err != nil {
@@ -70,8 +71,10 @@ func (c *Commands) changeUserEmailWithGenerator(ctx context.Context, userID, res
if err != nil {
return nil, err
}
if err = c.checkPermission(ctx, permissionUserWrite, cmd.aggregate.ResourceOwner, userID, true); err != nil {
return nil, err
if authz.GetCtxData(ctx).UserID != userID {
if err = c.checkPermission(ctx, domain.PermissionUserWrite, cmd.aggregate.ResourceOwner, userID); err != nil {
return nil, err
}
}
if err = cmd.Change(ctx, domain.EmailAddress(email)); err != nil {
return nil, err

View File

@@ -24,7 +24,7 @@ import (
func TestCommands_ChangeUserEmail(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
checkPermission permissionCheck
checkPermission domain.PermissionCheck
}
type args struct {
userID string
@@ -174,7 +174,7 @@ func TestCommands_ChangeUserEmail(t *testing.T) {
func TestCommands_ChangeUserEmailURLTemplate(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
checkPermission permissionCheck
checkPermission domain.PermissionCheck
}
type args struct {
userID string
@@ -300,7 +300,7 @@ func TestCommands_ChangeUserEmailURLTemplate(t *testing.T) {
func TestCommands_ChangeUserEmailReturnCode(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
checkPermission permissionCheck
checkPermission domain.PermissionCheck
}
type args struct {
userID string
@@ -410,7 +410,7 @@ func TestCommands_ChangeUserEmailReturnCode(t *testing.T) {
func TestCommands_ChangeUserEmailVerified(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
checkPermission permissionCheck
checkPermission domain.PermissionCheck
}
type args struct {
userID string
@@ -569,7 +569,7 @@ func TestCommands_ChangeUserEmailVerified(t *testing.T) {
func TestCommands_changeUserEmailWithGenerator(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
checkPermission permissionCheck
checkPermission domain.PermissionCheck
}
type args struct {
userID string

View File

@@ -2,13 +2,14 @@ package database
import (
"database/sql/driver"
"encoding/json"
"github.com/jackc/pgtype"
)
type StringArray []string
// Scan implements the `database/sql.Scanner` interface.
// Scan implements the [database/sql.Scanner] interface.
func (s *StringArray) Scan(src any) error {
array := new(pgtype.TextArray)
if err := array.Scan(src); err != nil {
@@ -20,7 +21,7 @@ func (s *StringArray) Scan(src any) error {
return nil
}
// Value implements the `database/sql/driver.Valuer` interface.
// Value implements the [database/sql/driver.Valuer] interface.
func (s StringArray) Value() (driver.Value, error) {
if len(s) == 0 {
return nil, nil
@@ -40,7 +41,7 @@ type enumField interface {
type EnumArray[F enumField] []F
// Scan implements the `database/sql.Scanner` interface.
// Scan implements the [database/sql.Scanner] interface.
func (s *EnumArray[F]) Scan(src any) error {
array := new(pgtype.Int2Array)
if err := array.Scan(src); err != nil {
@@ -57,7 +58,7 @@ func (s *EnumArray[F]) Scan(src any) error {
return nil
}
// Value implements the `database/sql/driver.Valuer` interface.
// Value implements the [database/sql/driver.Valuer] interface.
func (s EnumArray[F]) Value() (driver.Value, error) {
if len(s) == 0 {
return nil, nil
@@ -70,3 +71,25 @@ func (s EnumArray[F]) Value() (driver.Value, error) {
return array.Value()
}
type Map[V any] map[string]V
// Scan implements the [database/sql.Scanner] interface.
func (m *Map[V]) Scan(src any) error {
bytea := new(pgtype.Bytea)
if err := bytea.Scan(src); err != nil {
return err
}
if len(bytea.Bytes) == 0 {
return nil
}
return json.Unmarshal(bytea.Bytes, &m)
}
// Value implements the [database/sql/driver.Valuer] interface.
func (m Map[V]) Value() (driver.Value, error) {
if len(m) == 0 {
return nil, nil
}
return json.Marshal(m)
}

View File

@@ -0,0 +1,119 @@
package database
import (
"database/sql/driver"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMap_Scan(t *testing.T) {
type args struct {
src any
}
type res[V any] struct {
want Map[V]
err bool
}
type testCase[V any] struct {
name string
m Map[V]
args args
res[V]
}
tests := []testCase[string]{
{
"null",
Map[string]{},
args{src: "invalid"},
res[string]{
want: Map[string]{},
err: true,
},
},
{
"null",
Map[string]{},
args{src: nil},
res[string]{
want: Map[string]{},
},
},
{
"empty",
Map[string]{},
args{src: []byte(`{}`)},
res[string]{
want: Map[string]{},
},
},
{
"set",
Map[string]{},
args{src: []byte(`{"key": "value"}`)},
res[string]{
want: Map[string]{
"key": "value",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
func TestMap_Value(t *testing.T) {
type res struct {
want driver.Value
err bool
}
type testCase[V any] struct {
name string
m Map[V]
res res
}
tests := []testCase[string]{
{
"nil",
nil,
res{
want: nil,
},
},
{
"empty",
Map[string]{},
res{
want: nil,
},
},
{
"set",
Map[string]{
"key": "value",
},
res{
want: driver.Value([]byte(`{"key":"value"}`)),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.Value()
if tt.res.err {
assert.Error(t, err)
}
if !tt.res.err {
require.NoError(t, err)
assert.Equalf(t, tt.res.want, got, "Value()")
}
})
}
}

View File

@@ -1,5 +1,7 @@
package domain
import "context"
type Permissions struct {
Permissions []string
}
@@ -21,3 +23,12 @@ func (p *Permissions) appendPermission(ctxID, permission string) {
}
p.Permissions = append(p.Permissions, permission)
}
type PermissionCheck func(ctx context.Context, permission, orgID, resourceID string) (err error)
const (
PermissionUserWrite = "user.write"
PermissionSessionRead = "session.read"
PermissionSessionWrite = "session.write"
PermissionSessionDelete = "session.delete"
)

View File

@@ -0,0 +1,9 @@
package domain
type SessionState int32
const (
SessionStateUnspecified SessionState = iota
SessionStateActive
SessionStateTerminated
)

View File

@@ -25,3 +25,9 @@ func ExpectID(t *testing.T, id string) *MockGenerator {
m.EXPECT().Next().Return(id, nil)
return m
}
func NewIDGeneratorExpectError(t *testing.T, err error) *MockGenerator {
m := NewMockGenerator(gomock.NewController(t))
m.EXPECT().Next().Return("", err)
return m
}

View File

@@ -0,0 +1,41 @@
package integration
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
)
type DetailsMsg interface {
GetDetails() *object.Details
}
// AssertDetails asserts values in a message's object Details,
// if the object Details in expected is a non-nil value.
// It targets API v2 messages that have the `GetDetails()` method.
//
// Dynamically generated values are not compared with expected.
// Instead a sanity check is performed.
// For the sequence a non-zero value is expected.
// The change date has to be now, with a tollerance of 1 second.
//
// The resource owner is compared with expected and is
// therefore the only value that has to be set.
func AssertDetails[D DetailsMsg](t testing.TB, exptected, actual D) {
wantDetails, gotDetails := exptected.GetDetails(), actual.GetDetails()
if wantDetails == nil {
assert.Nil(t, gotDetails)
return
}
assert.NotZero(t, gotDetails.GetSequence())
gotCD := gotDetails.GetChangeDate().AsTime()
now := time.Now()
assert.WithinRange(t, gotCD, now.Add(-time.Second), now.Add(time.Second))
assert.Equal(t, wantDetails.GetResourceOwner(), gotDetails.GetResourceOwner())
}

View File

@@ -0,0 +1,51 @@
package integration
import (
"testing"
"google.golang.org/protobuf/types/known/timestamppb"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
)
type myMsg struct {
details *object.Details
}
func (m myMsg) GetDetails() *object.Details {
return m.details
}
func TestAssertDetails(t *testing.T) {
tests := []struct {
name string
exptected myMsg
actual myMsg
}{
{
name: "nil",
exptected: myMsg{},
actual: myMsg{},
},
{
name: "values",
exptected: myMsg{
details: &object.Details{
ResourceOwner: "me",
},
},
actual: myMsg{
details: &object.Details{
Sequence: 123,
ChangeDate: timestamppb.Now(),
ResourceOwner: "me",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertDetails(t, tt.exptected, tt.actual)
})
}
}

View File

@@ -0,0 +1,10 @@
Database:
cockroach:
Host: localhost
Port: 26257
Database: zitadel
Options: ""
User:
Username: zitadel
Admin:
Username: root

View File

@@ -0,0 +1,24 @@
version: '3.8'
services:
cockroach:
extends:
file: '../../../e2e/config/localhost/docker-compose.yaml'
service: 'db'
postgres:
restart: 'always'
image: 'postgres:15'
environment:
- POSTGRES_USER=zitadel
- PGUSER=zitadel
- POSTGRES_DB=zitadel
- POSTGRES_HOST_AUTH_METHOD=trust
healthcheck:
test: ["CMD-SHELL", "pg_isready"]
interval: '10s'
timeout: '30s'
retries: 5
start_period: '20s'
ports:
- 5432:5432

View File

@@ -0,0 +1,15 @@
Database:
postgres:
Host: localhost
Port: 5432
Database: zitadel
MaxOpenConns: 20
MaxIdleConns: 10
User:
Username: zitadel
SSL:
Mode: disable
Admin:
Username: zitadel
SSL:
Mode: disable

View File

@@ -0,0 +1,37 @@
Log:
Level: debug
TLS:
Enabled: false
FirstInstance:
Org:
Human:
PasswordChangeRequired: false
LogStore:
Access:
Database:
Enabled: true
Debounce:
MinFrequency: 0s
MaxBulkSize: 0
Execution:
Database:
Enabled: true
Stdout:
Enabled: true
Quotas:
Access:
ExhaustedCookieKey: "zitadel.quota.limiting"
ExhaustedCookieMaxAge: "60s"
Projections:
Customizations:
NotificationsQuotas:
RequeueEvery: 1s
DefaultInstance:
LoginPolicy:
MfaInitSkipLifetime: "0"

View File

@@ -0,0 +1,250 @@
// Package integration provides helpers for integration testing.
package integration
import (
"bytes"
"context"
"database/sql"
_ "embed"
"errors"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/spf13/viper"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v2/pkg/oidc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"github.com/zitadel/zitadel/cmd"
"github.com/zitadel/zitadel/cmd/start"
"github.com/zitadel/zitadel/internal/api/authz"
z_oidc "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/pkg/grpc/admin"
)
var (
//go:embed config/zitadel.yaml
zitadelYAML []byte
//go:embed config/cockroach.yaml
cockroachYAML []byte
//go:embed config/postgres.yaml
postgresYAML []byte
)
// UserType provides constants that give
// a short explinanation with the purpose
// a serverice user.
// This allows to pre-create users with
// different permissions and reuse them.
type UserType int
//go:generate stringer -type=UserType
const (
Unspecified UserType = iota
OrgOwner
)
// User information with a Personal Access Token.
type User struct {
*query.User
Token string
}
// Tester is a Zitadel server and client with all resources available for testing.
type Tester struct {
*start.Server
Instance authz.Instance
Organisation *query.Org
Users map[UserType]User
GRPCClientConn *grpc.ClientConn
wg sync.WaitGroup // used for shutdown
}
const commandLine = `start --masterkeyFromEnv`
func (s *Tester) Host() string {
return fmt.Sprintf("%s:%d", s.Config.ExternalDomain, s.Config.Port)
}
func (s *Tester) createClientConn(ctx context.Context) {
target := s.Host()
cc, err := grpc.DialContext(ctx, target,
grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
s.Shutdown <- os.Interrupt
s.wg.Wait()
}
logging.OnError(err).Fatal("integration tester client dial")
logging.New().WithField("target", target).Info("finished dialing grpc client conn")
s.GRPCClientConn = cc
err = s.pollHealth(ctx)
logging.OnError(err).Fatal("integration tester health")
}
// pollHealth waits until a healthy status is reported.
// TODO: remove when we make the setup blocking on all
// projections completed.
func (s *Tester) pollHealth(ctx context.Context) (err error) {
client := admin.NewAdminServiceClient(s.GRPCClientConn)
for {
err = func(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
_, err := client.Healthz(ctx, &admin.HealthzRequest{})
return err
}(ctx)
if err == nil {
return nil
}
logging.WithError(err).Info("poll healthz")
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
continue
}
}
}
const (
SystemUser = "integration"
)
func (s *Tester) createSystemUser(ctx context.Context) {
var err error
s.Instance, err = s.Queries.InstanceByHost(ctx, s.Host())
logging.OnError(err).Fatal("query instance")
ctx = authz.WithInstance(ctx, s.Instance)
s.Organisation, err = s.Queries.OrgByID(ctx, true, s.Instance.DefaultOrganisationID())
logging.OnError(err).Fatal("query organisation")
query, err := query.NewUserUsernameSearchQuery(SystemUser, query.TextEquals)
logging.OnError(err).Fatal("user query")
user, err := s.Queries.GetUser(ctx, true, true, query)
if errors.Is(err, sql.ErrNoRows) {
_, err = s.Commands.AddMachine(ctx, &command.Machine{
ObjectRoot: models.ObjectRoot{
ResourceOwner: s.Organisation.ID,
},
Username: SystemUser,
Name: SystemUser,
Description: "who cares?",
AccessTokenType: domain.OIDCTokenTypeJWT,
})
logging.OnError(err).Fatal("add machine user")
user, err = s.Queries.GetUser(ctx, true, true, query)
}
logging.OnError(err).Fatal("get user")
_, err = s.Commands.AddOrgMember(ctx, s.Organisation.ID, user.ID, "ORG_OWNER")
target := new(caos_errs.AlreadyExistsError)
if !errors.As(err, &target) {
logging.OnError(err).Fatal("add org member")
}
scopes := []string{oidc.ScopeOpenID, z_oidc.ScopeUserMetaData, z_oidc.ScopeResourceOwner}
pat := command.NewPersonalAccessToken(user.ResourceOwner, user.ID, time.Now().Add(time.Hour), scopes, domain.UserTypeMachine)
_, err = s.Commands.AddPersonalAccessToken(ctx, pat)
logging.OnError(err).Fatal("add pat")
s.Users = map[UserType]User{
OrgOwner: {
User: user,
Token: pat.Token,
},
}
}
func (s *Tester) WithSystemAuthorization(ctx context.Context, u UserType) context.Context {
return metadata.AppendToOutgoingContext(ctx, "Authorization", fmt.Sprintf("Bearer %s", s.Users[u].Token))
}
// Done send an interrupt signal to cleanly shutdown the server.
func (s *Tester) Done() {
err := s.GRPCClientConn.Close()
logging.OnError(err).Error("integration tester client close")
s.Shutdown <- os.Interrupt
s.wg.Wait()
}
// NewTester start a new Zitadel server by passing the default commandline.
// The server will listen on the configured port.
// The database configuration that will be used can be set by the
// INTEGRATION_DB_FLAVOR environment variable and can have the values "cockroach"
// or "postgres". Defaults to "cockroach".
//
// The deault Instance and Organisation are read from the DB and system
// users are created as needed.
//
// After the server is started, a [grpc.ClientConn] will be created and
// the server is polled for it's health status.
//
// Note: the database must already be setup and intialized before
// using NewTester. See the CONTRIBUTING.md document for details.
func NewTester(ctx context.Context) *Tester {
args := strings.Split(commandLine, " ")
sc := make(chan *start.Server)
//nolint:contextcheck
cmd := cmd.New(os.Stdout, os.Stdin, args, sc)
cmd.SetArgs(args)
err := viper.MergeConfig(bytes.NewBuffer(zitadelYAML))
logging.OnError(err).Fatal()
flavor := os.Getenv("INTEGRATION_DB_FLAVOR")
switch flavor {
case "cockroach", "":
err = viper.MergeConfig(bytes.NewBuffer(cockroachYAML))
case "postgres":
err = viper.MergeConfig(bytes.NewBuffer(postgresYAML))
default:
logging.New().WithField("flavor", flavor).Fatal("unknown db flavor set in INTEGRATION_DB_FLAVOR")
}
logging.OnError(err).Fatal()
tester := new(Tester)
tester.wg.Add(1)
go func(wg *sync.WaitGroup) {
logging.OnError(cmd.Execute()).Fatal()
wg.Done()
}(&tester.wg)
select {
case tester.Server = <-sc:
case <-ctx.Done():
logging.OnError(ctx.Err()).Fatal("waiting for integration tester server")
}
tester.createClientConn(ctx)
tester.createSystemUser(ctx)
return tester
}
func Contexts(timeout time.Duration) (ctx, errCtx context.Context, cancel context.CancelFunc) {
errCtx, cancel = context.WithCancel(context.Background())
cancel()
ctx, cancel = context.WithTimeout(context.Background(), timeout)
return ctx, errCtx, cancel
}

View File

@@ -0,0 +1,16 @@
//go:build integration
package integration
import (
"testing"
"time"
)
func TestNewTester(t *testing.T) {
ctx, _, cancel := Contexts(time.Hour)
defer cancel()
s := NewTester(ctx)
defer s.Done()
}

View File

@@ -0,0 +1,24 @@
// Code generated by "stringer -type=UserType"; DO NOT EDIT.
package integration
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unspecified-0]
_ = x[OrgOwner-1]
}
const _UserType_name = "UnspecifiedOrgOwner"
var _UserType_index = [...]uint8{0, 11, 19}
func (i UserType) String() string {
if i < 0 || i >= UserType(len(_UserType_index)-1) {
return "UserType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _UserType_name[_UserType_index[i]:_UserType_index[i+1]]
}

View File

@@ -0,0 +1,147 @@
package main
import (
"bytes"
_ "embed"
"fmt"
"io"
"os"
"text/template"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/pluginpb"
protoc_gen_zitadel "github.com/zitadel/zitadel/pkg/grpc/protoc/v2"
)
var (
//go:embed zitadel.pb.go.tmpl
zitadelTemplate []byte
)
type authMethods struct {
GoPackageName string
ProtoPackageName string
ServiceName string
AuthOptions []authOption
AuthContext []authContext
CustomHTTPResponses []httpResponse
}
type authOption struct {
Name string
Permission string
CheckFieldName string
}
type authContext struct {
Name string
OrgMethod string
}
type httpResponse struct {
Name string
Code int32
}
func main() {
input, _ := io.ReadAll(os.Stdin)
var req pluginpb.CodeGeneratorRequest
err := proto.Unmarshal(input, &req)
if err != nil {
panic(err)
}
opts := protogen.Options{}
plugin, err := opts.New(&req)
if err != nil {
panic(err)
}
plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
tmpl := loadTemplate(zitadelTemplate)
for _, file := range plugin.Files {
methods := new(authMethods)
for _, service := range file.Services {
methods.ServiceName = service.GoName
methods.GoPackageName = string(file.GoPackageName)
methods.ProtoPackageName = *file.Proto.Package
for _, method := range service.Methods {
options := method.Desc.Options().(*descriptorpb.MethodOptions)
if options == nil {
continue
}
ext := proto.GetExtension(options, protoc_gen_zitadel.E_Options).(*protoc_gen_zitadel.Options)
if ext == nil {
continue
}
if ext.AuthOption != nil {
generateAuthOption(methods, ext.AuthOption, method)
}
if ext.HttpResponse != nil {
methods.CustomHTTPResponses = append(methods.CustomHTTPResponses, httpResponse{Name: string(method.Output.Desc.Name()), Code: ext.HttpResponse.SuccessCode})
}
}
}
if len(methods.AuthOptions) > 0 {
generateFile(tmpl, methods, file, plugin)
}
}
// Generate a response from our plugin and marshall as protobuf
stdout := plugin.Response()
out, err := proto.Marshal(stdout)
if err != nil {
panic(err)
}
// Write the response to stdout, to be picked up by protoc
_, err = fmt.Fprint(os.Stdout, string(out))
if err != nil {
panic(err)
}
}
func generateAuthOption(methods *authMethods, protoAuthOption *protoc_gen_zitadel.AuthOption, method *protogen.Method) {
methods.AuthOptions = append(methods.AuthOptions, authOption{Name: string(method.Desc.Name()), Permission: protoAuthOption.Permission})
if protoAuthOption.OrgField == "" {
return
}
orgMethod := buildAuthContextField(method.Input.Fields, protoAuthOption.OrgField)
if orgMethod != "" {
methods.AuthContext = append(methods.AuthContext, authContext{Name: string(method.Input.Desc.Name()), OrgMethod: orgMethod})
}
}
func generateFile(tmpl *template.Template, methods *authMethods, protoFile *protogen.File, plugin *protogen.Plugin) {
var buffer bytes.Buffer
err := tmpl.Execute(&buffer, &methods)
if err != nil {
panic(err)
}
filename := protoFile.GeneratedFilenamePrefix + ".pb.zitadel.go"
file := plugin.NewGeneratedFile(filename, ".")
_, err = file.Write(buffer.Bytes())
if err != nil {
panic(err)
}
}
func loadTemplate(templateData []byte) *template.Template {
return template.Must(template.New("").
Parse(string(templateData)))
}
func buildAuthContextField(fields []*protogen.Field, fieldName string) string {
for _, field := range fields {
if string(field.Desc.Name()) == fieldName {
return ".Get" + field.GoName + "()"
}
}
return ""
}

View File

@@ -0,0 +1,29 @@
// Code generated by protoc-gen-zitadel. DO NOT EDIT.
package {{.GoPackageName}}
import (
"github.com/zitadel/zitadel/internal/api/authz"
{{if .AuthContext}}"github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"{{end}}
)
var {{.ServiceName}}_AuthMethods = authz.MethodMapping {
{{ range $m := .AuthOptions}}
{{$.ServiceName}}_{{$m.Name}}_FullMethodName: authz.Option{
Permission: "{{$m.Permission}}",
CheckParam: "{{$m.CheckFieldName}}",
},
{{ end}}
}
{{ range $m := .AuthContext}}
func (r *{{ $m.Name }}) OrganisationFromRequest() *object.Organisation {
return r{{$m.OrgMethod}}
}
{{ end }}
{{ range $resp := .CustomHTTPResponses}}
func (r *{{ $resp.Name }}) CustomHTTPCode() int {
return {{$resp.Code}}
}
{{ end }}

View File

@@ -180,12 +180,20 @@ func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUniqu
return scan(row)
}
func (q *Queries) ExistsOrg(ctx context.Context, id string) (err error) {
func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = q.OrgByID(ctx, true, id)
return err
var org *Org
if id != "" {
org, err = q.OrgByID(ctx, true, id)
} else {
org, err = q.OrgByVerifiedDomain(ctx, domain)
}
if err != nil {
return "", err
}
return org.ID, nil
}
func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries) (orgs *Orgs, err error) {

View File

@@ -65,6 +65,7 @@ var (
NotificationsProjection interface{}
NotificationsQuotaProjection interface{}
DeviceAuthProjection *deviceAuthProjection
SessionProjection *sessionProjection
)
type projection interface {
@@ -141,6 +142,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
SecurityPolicyProjection = newSecurityPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["security_policies"]))
NotificationPolicyProjection = newNotificationPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["notification_policies"]))
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
newProjectionsList()
return nil
}
@@ -237,5 +239,6 @@ func newProjectionsList() {
SecurityPolicyProjection,
NotificationPolicyProjection,
DeviceAuthProjection,
SessionProjection,
}
}

View File

@@ -0,0 +1,221 @@
package projection
import (
"context"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/session"
)
const (
SessionsProjectionTable = "projections.sessions"
SessionColumnID = "id"
SessionColumnCreationDate = "creation_date"
SessionColumnChangeDate = "change_date"
SessionColumnSequence = "sequence"
SessionColumnState = "state"
SessionColumnResourceOwner = "resource_owner"
SessionColumnInstanceID = "instance_id"
SessionColumnCreator = "creator"
SessionColumnUserID = "user_id"
SessionColumnUserCheckedAt = "user_checked_at"
SessionColumnPasswordCheckedAt = "password_checked_at"
SessionColumnMetadata = "metadata"
SessionColumnTokenID = "token_id"
)
type sessionProjection struct {
crdb.StatementHandler
}
func newSessionProjection(ctx context.Context, config crdb.StatementHandlerConfig) *sessionProjection {
p := new(sessionProjection)
config.ProjectionName = SessionsProjectionTable
config.Reducers = p.reducers()
config.InitCheck = crdb.NewMultiTableCheck(
crdb.NewTable([]*crdb.Column{
crdb.NewColumn(SessionColumnID, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnCreationDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(SessionColumnChangeDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(SessionColumnSequence, crdb.ColumnTypeInt64),
crdb.NewColumn(SessionColumnState, crdb.ColumnTypeEnum),
crdb.NewColumn(SessionColumnResourceOwner, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnCreator, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnUserID, crdb.ColumnTypeText, crdb.Nullable()),
crdb.NewColumn(SessionColumnUserCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnPasswordCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnMetadata, crdb.ColumnTypeJSONB, crdb.Nullable()),
crdb.NewColumn(SessionColumnTokenID, crdb.ColumnTypeText, crdb.Nullable()),
},
crdb.NewPrimaryKey(SessionColumnInstanceID, SessionColumnID),
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
return p
}
func (p *sessionProjection) reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: session.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: session.AddedType,
Reduce: p.reduceSessionAdded,
},
{
Event: session.UserCheckedType,
Reduce: p.reduceUserChecked,
},
{
Event: session.PasswordCheckedType,
Reduce: p.reducePasswordChecked,
},
{
Event: session.TokenSetType,
Reduce: p.reduceTokenSet,
},
{
Event: session.MetadataSetType,
Reduce: p.reduceMetadataSet,
},
{
Event: session.TerminateType,
Reduce: p.reduceSessionTerminated,
},
},
},
{
Aggregate: instance.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: instance.InstanceRemovedEventType,
Reduce: reduceInstanceRemovedHelper(SMSColumnInstanceID),
},
},
},
}
}
func (p *sessionProjection) reduceSessionAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.AddedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Sfrgf", "reduce.wrong.event.type %s", session.AddedType)
}
return crdb.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnID, e.Aggregate().ID),
handler.NewCol(SessionColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(SessionColumnCreationDate, e.CreationDate()),
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner),
handler.NewCol(SessionColumnState, domain.SessionStateActive),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnCreator, e.User),
},
), nil
}
func (p *sessionProjection) reduceUserChecked(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.UserCheckedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-saDg5", "reduce.wrong.event.type %s", session.UserCheckedType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnUserID, e.UserID),
handler.NewCol(SessionColumnUserCheckedAt, e.CheckedAt),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reducePasswordChecked(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.PasswordCheckedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SDgrb", "reduce.wrong.event.type %s", session.PasswordCheckedType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnPasswordCheckedAt, e.CheckedAt),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reduceTokenSet(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.TokenSetEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SAfd3", "reduce.wrong.event.type %s", session.TokenSetType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnTokenID, e.TokenID),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reduceMetadataSet(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.MetadataSetEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SAfd3", "reduce.wrong.event.type %s", session.MetadataSetType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnMetadata, e.Metadata),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reduceSessionTerminated(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.TerminateEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SAftn", "reduce.wrong.event.type %s", session.TerminateType)
}
return crdb.NewDeleteStatement(
e,
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}

View File

@@ -0,0 +1,260 @@
package projection
import (
"testing"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/session"
)
func TestSessionProjection_reduces(t *testing.T) {
type args struct {
event func(t *testing.T) eventstore.Event
}
tests := []struct {
name string
args args
reduce func(event eventstore.Event) (*handler.Statement, error)
want wantReduce
}{
{
name: "instance reduceSessionAdded",
args: args{
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{}`),
), session.AddedEventMapper),
},
reduce: (&sessionProjection{}).reduceSessionAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.sessions (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
anyArg{},
anyArg{},
"ro-id",
domain.SessionStateActive,
uint64(15),
"editor-user",
},
},
},
},
},
},
{
name: "instance reduceUserChecked",
args: args{
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{
"userId": "user-id",
"checkedAt": "2023-05-04T00:00:00Z"
}`),
), session.UserCheckedEventMapper),
},
reduce: (&sessionProjection{}).reduceUserChecked,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
"user-id",
time.Date(2023, time.May, 4, 0, 0, 0, 0, time.UTC),
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reducePasswordChecked",
args: args{
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{
"checkedAt": "2023-05-04T00:00:00Z"
}`),
), session.PasswordCheckedEventMapper),
},
reduce: (&sessionProjection{}).reducePasswordChecked,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
time.Date(2023, time.May, 4, 0, 0, 0, 0, time.UTC),
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceTokenSet",
args: args{
event: getEvent(testEvent(
session.TokenSetType,
session.AggregateType,
[]byte(`{
"tokenID": "tokenID"
}`),
), session.TokenSetEventMapper),
},
reduce: (&sessionProjection{}).reduceTokenSet,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
"tokenID",
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceMetadataSet",
args: args{
event: getEvent(testEvent(
session.MetadataSetType,
session.AggregateType,
[]byte(`{
"metadata": {
"key": "dmFsdWU="
}
}`),
), session.MetadataSetEventMapper),
},
reduce: (&sessionProjection{}).reduceMetadataSet,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
map[string][]byte{
"key": []byte("value"),
},
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceSessionTerminated",
args: args{
event: getEvent(testEvent(
session.TerminateType,
session.AggregateType,
[]byte(`{}`),
), session.TerminateEventMapper),
},
reduce: (&sessionProjection{}).reduceSessionTerminated,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.sessions WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceInstanceRemoved",
args: args{
event: getEvent(testEvent(
repository.EventType(instance.InstanceRemovedEventType),
instance.AggregateType,
nil,
), instance.InstanceRemovedEventMapper),
},
reduce: reduceInstanceRemovedHelper(SessionColumnInstanceID),
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.sessions WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"agg-id",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event := baseEvent(t)
got, err := tt.reduce(event)
if !errors.IsErrorInvalidArgument(err) {
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
}
event = tt.args.event(t)
got, err = tt.reduce(event)
assertReduce(t, got, err, SessionsProjectionTable, tt.want)
})
}
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/session"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
)
@@ -30,7 +31,8 @@ type Queries struct {
eventstore *eventstore.Eventstore
client *database.DB
idpConfigEncryption crypto.EncryptionAlgorithm
idpConfigEncryption crypto.EncryptionAlgorithm
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error)
DefaultLanguage language.Tag
LoginDir http.FileSystem
@@ -43,7 +45,16 @@ type Queries struct {
multifactors domain.MultifactorConfigs
}
func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *database.DB, projections projection.Config, defaults sd.SystemDefaults, idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm, zitadelRoles []authz.RoleMapping) (repo *Queries, err error) {
func StartQueries(
ctx context.Context,
es *eventstore.Eventstore,
sqlClient *database.DB,
projections projection.Config,
defaults sd.SystemDefaults,
idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm,
zitadelRoles []authz.RoleMapping,
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
) (repo *Queries, err error) {
statikLoginFS, err := fs.NewWithNamespace("login")
if err != nil {
return nil, fmt.Errorf("unable to start login statik dir")
@@ -63,6 +74,7 @@ func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *dat
LoginTranslationFileContents: make(map[string][]byte),
NotificationTranslationFileContents: make(map[string][]byte),
zitadelRoles: zitadelRoles,
sessionTokenVerifier: sessionTokenVerifier,
}
iam_repo.RegisterEventMappers(repo.eventstore)
usr_repo.RegisterEventMappers(repo.eventstore)
@@ -71,6 +83,7 @@ func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *dat
action.RegisterEventMappers(repo.eventstore)
keypair.RegisterEventMappers(repo.eventstore)
usergrant.RegisterEventMappers(repo.eventstore)
session.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{

320
internal/query/session.go Normal file
View File

@@ -0,0 +1,320 @@
package query
import (
"context"
"database/sql"
errs "errors"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type Sessions struct {
SearchResponse
Sessions []*Session
}
type Session struct {
ID string
CreationDate time.Time
ChangeDate time.Time
Sequence uint64
State domain.SessionState
ResourceOwner string
Creator string
UserFactor SessionUserFactor
PasswordFactor SessionPasswordFactor
Metadata map[string][]byte
}
type SessionUserFactor struct {
UserID string
UserCheckedAt time.Time
LoginName string
DisplayName string
}
type SessionPasswordFactor struct {
PasswordCheckedAt time.Time
}
type SessionsSearchQueries struct {
SearchRequest
Queries []SearchQuery
}
func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
query = q.SearchRequest.toQuery(query)
for _, q := range q.Queries {
query = q.toQuery(query)
}
return query
}
var (
sessionsTable = table{
name: projection.SessionsProjectionTable,
instanceIDCol: projection.SessionColumnInstanceID,
}
SessionColumnID = Column{
name: projection.SessionColumnID,
table: sessionsTable,
}
SessionColumnCreationDate = Column{
name: projection.SessionColumnCreationDate,
table: sessionsTable,
}
SessionColumnChangeDate = Column{
name: projection.SessionColumnChangeDate,
table: sessionsTable,
}
SessionColumnSequence = Column{
name: projection.SessionColumnSequence,
table: sessionsTable,
}
SessionColumnState = Column{
name: projection.SessionColumnState,
table: sessionsTable,
}
SessionColumnResourceOwner = Column{
name: projection.SessionColumnResourceOwner,
table: sessionsTable,
}
SessionColumnInstanceID = Column{
name: projection.SessionColumnInstanceID,
table: sessionsTable,
}
SessionColumnCreator = Column{
name: projection.SessionColumnCreator,
table: sessionsTable,
}
SessionColumnUserID = Column{
name: projection.SessionColumnUserID,
table: sessionsTable,
}
SessionColumnUserCheckedAt = Column{
name: projection.SessionColumnUserCheckedAt,
table: sessionsTable,
}
SessionColumnPasswordCheckedAt = Column{
name: projection.SessionColumnPasswordCheckedAt,
table: sessionsTable,
}
SessionColumnMetadata = Column{
name: projection.SessionColumnMetadata,
table: sessionsTable,
}
SessionColumnToken = Column{
name: projection.SessionColumnTokenID,
table: sessionsTable,
}
)
func (q *Queries) SessionByID(ctx context.Context, id, sessionToken string) (_ *Session, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareSessionQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.Eq{
SessionColumnID.identifier(): id,
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement")
}
row := q.client.QueryRowContext(ctx, stmt, args...)
session, tokenID, err := scan(row)
if err != nil {
return nil, err
}
if sessionToken == "" {
return session, nil
}
if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil {
return nil, errors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied")
}
return session, nil
}
func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries) (_ *Sessions, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareSessionsQuery(ctx, q.client)
stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-sn9Jf", "Errors.Query.InvalidRequest")
}
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil || rows.Err() != nil {
return nil, errors.ThrowInternal(err, "QUERY-Sfg42", "Errors.Internal")
}
sessions, err := scan(rows)
if err != nil {
return nil, err
}
sessions.LatestSequence, err = q.latestSequence(ctx, sessionsTable)
return sessions, err
}
func NewSessionIDsSearchQuery(ids []string) (SearchQuery, error) {
list := make([]interface{}, len(ids))
for i, value := range ids {
list[i] = value
}
return NewListQuery(SessionColumnID, list, ListIn)
}
func NewSessionCreatorSearchQuery(creator string) (SearchQuery, error) {
return NewTextQuery(SessionColumnCreator, creator, TextEquals)
}
func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) {
return sq.Select(
SessionColumnID.identifier(),
SessionColumnCreationDate.identifier(),
SessionColumnChangeDate.identifier(),
SessionColumnSequence.identifier(),
SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(),
SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnMetadata.identifier(),
SessionColumnToken.identifier(),
).From(sessionsTable.identifier()).
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Session, string, error) {
session := new(Session)
var (
userID sql.NullString
userCheckedAt sql.NullTime
loginName sql.NullString
displayName sql.NullString
passwordCheckedAt sql.NullTime
metadata database.Map[[]byte]
token sql.NullString
)
err := row.Scan(
&session.ID,
&session.CreationDate,
&session.ChangeDate,
&session.Sequence,
&session.State,
&session.ResourceOwner,
&session.Creator,
&userID,
&userCheckedAt,
&loginName,
&displayName,
&passwordCheckedAt,
&metadata,
&token,
)
if err != nil {
if errs.Is(err, sql.ErrNoRows) {
return nil, "", errors.ThrowNotFound(err, "QUERY-SFeaa", "Errors.Session.NotExisting")
}
return nil, "", errors.ThrowInternal(err, "QUERY-SAder", "Errors.Internal")
}
session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.Metadata = metadata
return session, token.String, nil
}
}
func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Sessions, error)) {
return sq.Select(
SessionColumnID.identifier(),
SessionColumnCreationDate.identifier(),
SessionColumnChangeDate.identifier(),
SessionColumnSequence.identifier(),
SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(),
SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnMetadata.identifier(),
countColumn.identifier(),
).From(sessionsTable.identifier()).
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Sessions, error) {
sessions := &Sessions{Sessions: []*Session{}}
for rows.Next() {
session := new(Session)
var (
userID sql.NullString
userCheckedAt sql.NullTime
loginName sql.NullString
displayName sql.NullString
passwordCheckedAt sql.NullTime
metadata database.Map[[]byte]
)
err := rows.Scan(
&session.ID,
&session.CreationDate,
&session.ChangeDate,
&session.Sequence,
&session.State,
&session.ResourceOwner,
&session.Creator,
&userID,
&userCheckedAt,
&loginName,
&displayName,
&passwordCheckedAt,
&metadata,
&sessions.Count,
)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SAfeg", "Errors.Internal")
}
session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.Metadata = metadata
sessions.Sessions = append(sessions.Sessions, session)
}
return sessions, nil
}
}

View File

@@ -0,0 +1,396 @@
package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/domain"
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions.id,` +
` projections.sessions.creation_date,` +
` projections.sessions.change_date,` +
` projections.sessions.sequence,` +
` projections.sessions.state,` +
` projections.sessions.resource_owner,` +
` projections.sessions.creator,` +
` projections.sessions.user_id,` +
` projections.sessions.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.sessions.password_checked_at,` +
` projections.sessions.metadata,` +
` projections.sessions.token_id` +
` FROM projections.sessions` +
` LEFT JOIN projections.login_names2 ON projections.sessions.user_id = projections.login_names2.user_id AND projections.sessions.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions.user_id = projections.users8_humans.user_id AND projections.sessions.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions.id,` +
` projections.sessions.creation_date,` +
` projections.sessions.change_date,` +
` projections.sessions.sequence,` +
` projections.sessions.state,` +
` projections.sessions.resource_owner,` +
` projections.sessions.creator,` +
` projections.sessions.user_id,` +
` projections.sessions.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.sessions.password_checked_at,` +
` projections.sessions.metadata,` +
` COUNT(*) OVER ()` +
` FROM projections.sessions` +
` LEFT JOIN projections.login_names2 ON projections.sessions.user_id = projections.login_names2.user_id AND projections.sessions.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions.user_id = projections.users8_humans.user_id AND projections.sessions.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
sessionCols = []string{
"id",
"creation_date",
"change_date",
"sequence",
"state",
"resource_owner",
"creator",
"user_id",
"user_checked_at",
"login_name",
"display_name",
"password_checked_at",
"metadata",
"token",
}
sessionsCols = []string{
"id",
"creation_date",
"change_date",
"sequence",
"state",
"resource_owner",
"creator",
"user_id",
"user_checked_at",
"login_name",
"display_name",
"password_checked_at",
"metadata",
"count",
}
)
func Test_SessionsPrepare(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareSessionsQuery no result",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueries(
expectedSessionsQuery,
nil,
nil,
),
},
object: &Sessions{Sessions: []*Session{}},
},
{
name: "prepareSessionQuery",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueries(
expectedSessionsQuery,
sessionsCols,
[][]driver.Value{
{
"session-id",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
},
),
},
object: &Sessions{
SearchResponse: SearchResponse{
Count: 1,
},
Sessions: []*Session{
{
ID: "session-id",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
},
},
},
{
name: "prepareSessionsQuery multiple result",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueries(
expectedSessionsQuery,
sessionsCols,
[][]driver.Value{
{
"session-id",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
{
"session-id2",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator2",
"user-id2",
testNow,
"login-name2",
"display-name2",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
},
),
},
object: &Sessions{
SearchResponse: SearchResponse{
Count: 2,
},
Sessions: []*Session{
{
ID: "session-id",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
{
ID: "session-id2",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator2",
UserFactor: SessionUserFactor{
UserID: "user-id2",
UserCheckedAt: testNow,
LoginName: "login-name2",
DisplayName: "display-name2",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
},
},
},
{
name: "prepareSessionsQuery sql err",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedSessionsQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
func Test_SessionPrepare(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareSessionQuery no result",
prepare: prepareSessionQueryTesting(t, ""),
want: want{
sqlExpectations: mockQueries(
expectedSessionQuery,
nil,
nil,
),
err: func(err error) (error, bool) {
if !errs.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
},
object: (*Session)(nil),
},
{
name: "prepareSessionQuery found",
prepare: prepareSessionQueryTesting(t, "tokenID"),
want: want{
sqlExpectations: mockQuery(
expectedSessionQuery,
sessionCols,
[]driver.Value{
"session-id",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
"tokenID",
},
),
},
object: &Session{
ID: "session-id",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
},
{
name: "prepareSessionQuery sql err",
prepare: prepareSessionQueryTesting(t, ""),
want: want{
sqlExpectations: mockQueryErr(
expectedSessionQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
func prepareSessionQueryTesting(t *testing.T, token string) func(context.Context, prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, error)) {
return func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, error)) {
builder, scan := prepareSessionQuery(ctx, db)
return builder, func(row *sql.Row) (*Session, error) {
session, tokenID, err := scan(row)
require.Equal(t, tokenID, token)
return session, err
}
}
}

View File

@@ -0,0 +1,25 @@
package session
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
AggregateType = "session"
AggregateVersion = "v1"
)
type Aggregate struct {
eventstore.Aggregate
}
func NewAggregate(id, resourceOwner string) *Aggregate {
return &Aggregate{
Aggregate: eventstore.Aggregate{
Type: AggregateType,
Version: AggregateVersion,
ID: id,
ResourceOwner: resourceOwner,
},
}
}

View File

@@ -0,0 +1,12 @@
package session
import "github.com/zitadel/zitadel/internal/eventstore"
func RegisterEventMappers(es *eventstore.Eventstore) {
es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper).
RegisterFilterEventMapper(AggregateType, UserCheckedType, UserCheckedEventMapper).
RegisterFilterEventMapper(AggregateType, PasswordCheckedType, PasswordCheckedEventMapper).
RegisterFilterEventMapper(AggregateType, TokenSetType, TokenSetEventMapper).
RegisterFilterEventMapper(AggregateType, MetadataSetType, MetadataSetEventMapper).
RegisterFilterEventMapper(AggregateType, TerminateType, TerminateEventMapper)
}

View File

@@ -0,0 +1,255 @@
package session
import (
"context"
"encoding/json"
"time"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
const (
sessionEventPrefix = "session."
AddedType = sessionEventPrefix + "added"
UserCheckedType = sessionEventPrefix + "user.checked"
PasswordCheckedType = sessionEventPrefix + "password.checked"
TokenSetType = sessionEventPrefix + "token.set"
MetadataSetType = sessionEventPrefix + "metadata.set"
TerminateType = sessionEventPrefix + "terminated"
)
type AddedEvent struct {
eventstore.BaseEvent `json:"-"`
}
func (e *AddedEvent) Data() interface{} {
return e
}
func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewAddedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
) *AddedEvent {
return &AddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
AddedType,
),
}
}
func AddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &AddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "SESSION-DG4gn", "unable to unmarshal session added")
}
return added, nil
}
type UserCheckedEvent struct {
eventstore.BaseEvent `json:"-"`
UserID string `json:"userID"`
CheckedAt time.Time `json:"checkedAt"`
}
func (e *UserCheckedEvent) Data() interface{} {
return e
}
func (e *UserCheckedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewUserCheckedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
userID string,
checkedAt time.Time,
) *UserCheckedEvent {
return &UserCheckedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
UserCheckedType,
),
UserID: userID,
CheckedAt: checkedAt,
}
}
func UserCheckedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &UserCheckedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "SESSION-DSGn5", "unable to unmarshal user checked")
}
return added, nil
}
type PasswordCheckedEvent struct {
eventstore.BaseEvent `json:"-"`
CheckedAt time.Time `json:"checkedAt"`
}
func (e *PasswordCheckedEvent) Data() interface{} {
return e
}
func (e *PasswordCheckedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewPasswordCheckedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
checkedAt time.Time,
) *PasswordCheckedEvent {
return &PasswordCheckedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
PasswordCheckedType,
),
CheckedAt: checkedAt,
}
}
func PasswordCheckedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &PasswordCheckedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "SESSION-DGt21", "unable to unmarshal password checked")
}
return added, nil
}
type TokenSetEvent struct {
eventstore.BaseEvent `json:"-"`
TokenID string `json:"tokenID"`
}
func (e *TokenSetEvent) Data() interface{} {
return e
}
func (e *TokenSetEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewTokenSetEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
tokenID string,
) *TokenSetEvent {
return &TokenSetEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
TokenSetType,
),
TokenID: tokenID,
}
}
func TokenSetEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &TokenSetEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "SESSION-Sf3va", "unable to unmarshal token set")
}
return added, nil
}
type MetadataSetEvent struct {
eventstore.BaseEvent `json:"-"`
Metadata map[string][]byte `json:"metadata"`
}
func (e *MetadataSetEvent) Data() interface{} {
return e
}
func (e *MetadataSetEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewMetadataSetEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
metadata map[string][]byte,
) *MetadataSetEvent {
return &MetadataSetEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
MetadataSetType,
),
Metadata: metadata,
}
}
func MetadataSetEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &MetadataSetEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "SESSION-BD21d", "unable to unmarshal metadata set")
}
return added, nil
}
type TerminateEvent struct {
eventstore.BaseEvent `json:"-"`
}
func (e *TerminateEvent) Data() interface{} {
return e
}
func (e *TerminateEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewTerminateEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
) *TerminateEvent {
return &TerminateEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
TerminateType,
),
}
}
func TerminateEventMapper(event *repository.Event) (eventstore.Event, error) {
return &TerminateEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}, nil
}

View File

@@ -470,6 +470,11 @@ Errors:
Execution:
StorageFailed: Das Speichern des Action Logs in der Datenbank ist fehlgeschlagen
ScanFailed: Das Abfragen der verbrauchten Actions Sekunden ist fehlgeschlagen
Session:
NotExisting: Session existiert nicht
Terminated: Session bereits beendet
Token:
Invalid: Session Token ist ungültig
AggregateTypes:
action: Action

View File

@@ -470,6 +470,11 @@ Errors:
Execution:
StorageFailed: Storing action execution log to database failed
ScanFailed: Querying usage for action execution seconds failed
Session:
NotExisting: Session does not exist
Terminated: Session already terminated
Token:
Invalid: Session Token is invalid
AggregateTypes:
action: Action

View File

@@ -470,6 +470,11 @@ Errors:
Execution:
StorageFailed: Ha fallado el almacenaje del registro de ejecución de acciones en la base de datos
ScanFailed: La consulta de uso de los segundos de ejecuciónde acciones ha fallado
Session:
NotExisting: La sesión no existe
Terminated: Sesión ya terminada
Token:
Invalid: El identificador de sesión no es válido
AggregateTypes:
action: Acción

View File

@@ -470,6 +470,11 @@ Errors:
Execution:
StorageFailed: L'enregistrement du journal d'action dans la base de données a échoué
ScanFailed: L'interrogation des secondes d'action consommées a échoué
Session:
NotExisting: La session n'existe pas
Terminated: La session est déjà terminée
Token:
Invalid: Le jeton de session n'est pas valide
AggregateTypes:
action: Action

View File

@@ -470,6 +470,11 @@ Errors:
Execution:
StorageFailed: Il salvataggio del registro delle azioni nel database non è riuscito
ScanFailed: La query dei secondi delle azioni utilizzate non è riuscita
Session:
NotExisting: La sessione non esiste
Terminated: Sessione già terminata
Token:
Invalid: Il token della sessione non è valido
AggregateTypes:
action: Azione

View File

@@ -459,6 +459,11 @@ Errors:
Execution:
StorageFailed: アクション実行ログのデータベースへの保存に失敗しました
ScanFailed: アクション実行時間を取得する使用状況クエリに失敗しました
Session:
NotExisting: セッションが存在しない
Terminated: セッションはすでに終了しています
Token:
Invalid: セッショントークンが無効です
AggregateTypes:
action: アクション

View File

@@ -470,6 +470,11 @@ Errors:
Execution:
StorageFailed: Zapisywanie dziennika wykonania akcji do bazy danych nie powiodło się
ScanFailed: Zapytanie o użycie dla sekund wykonania akcji nie powiodło się
Session:
NotExisting: Sesja nie istnieje
Terminated: Sesja już zakończona
Token:
Invalid: Token sesji jest nieprawidłowy
AggregateTypes:
action: Działanie

View File

@@ -470,6 +470,11 @@ Errors:
Execution:
StorageFailed: 将行动执行日志存储到数据库失败
ScanFailed: Q查询动作执行秒数的使用情况失败
Session:
NotExisting: 会话不存在
Terminated: 会话已经终止
Token:
Invalid: 会话令牌是无效的
AggregateTypes:
action: 动作