package session import ( "context" "net" "net/http" "time" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/muhlemmer/gu" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/query" objpb "github.com/zitadel/zitadel/pkg/grpc/object" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" ) var ( timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, } ) func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) if err != nil { return nil, err } return &session.GetSessionResponse{ Session: sessionToPb(res), }, nil } func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { queries, err := listSessionsRequestToQuery(ctx, req) if err != nil { return nil, err } sessions, err := s.query.SearchSessions(ctx, queries) if err != nil { return nil, err } return &session.ListSessionsResponse{ Details: object.ToListDetails(sessions.SearchResponse), Sessions: sessionsToPb(sessions.Sessions), }, nil } func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) { checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req) if err != nil { return nil, err } challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) if err != nil { return nil, err } set, err := s.command.CreateSession(ctx, cmds, metadata, userAgent, lifetime) if err != nil { return nil, err } return &session.CreateSessionResponse{ Details: object.DomainToDetailsPb(set.ObjectDetails), SessionId: set.ID, SessionToken: set.NewToken, Challenges: challengeResponse, }, nil } func (s *Server) SetSession(ctx context.Context, req *session.SetSessionRequest) (*session.SetSessionResponse, error) { checks, err := s.setSessionRequestToCommand(ctx, req) if err != nil { return nil, err } challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) if err != nil { return nil, err } set, err := s.command.UpdateSession(ctx, req.GetSessionId(), req.GetSessionToken(), cmds, req.GetMetadata(), req.GetLifetime().AsDuration()) if err != nil { return nil, err } // if there's no new token, just return the current if set.NewToken == "" { set.NewToken = req.GetSessionToken() } return &session.SetSessionResponse{ Details: object.DomainToDetailsPb(set.ObjectDetails), SessionToken: set.NewToken, Challenges: challengeResponse, }, nil } func (s *Server) DeleteSession(ctx context.Context, req *session.DeleteSessionRequest) (*session.DeleteSessionResponse, error) { details, err := s.command.TerminateSession(ctx, req.GetSessionId(), req.GetSessionToken()) if err != nil { return nil, err } return &session.DeleteSessionResponse{ Details: object.DomainToDetailsPb(details), }, nil } func sessionsToPb(sessions []*query.Session) []*session.Session { s := make([]*session.Session, len(sessions)) for i, session := range sessions { s[i] = sessionToPb(session) } return s } func sessionToPb(s *query.Session) *session.Session { return &session.Session{ Id: s.ID, CreationDate: timestamppb.New(s.CreationDate), ChangeDate: timestamppb.New(s.ChangeDate), Sequence: s.Sequence, Factors: factorsToPb(s), Metadata: s.Metadata, UserAgent: userAgentToPb(s.UserAgent), ExpirationDate: expirationToPb(s.Expiration), } } func userAgentToPb(ua domain.UserAgent) *session.UserAgent { if ua.IsEmpty() { return nil } out := &session.UserAgent{ FingerprintId: ua.FingerprintID, Description: ua.Description, } if ua.IP != nil { out.Ip = gu.Ptr(ua.IP.String()) } if ua.Header == nil { return out } out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) for k, v := range ua.Header { out.Header[k] = &session.UserAgent_HeaderValues{ Values: v, } } return out } func expirationToPb(expiration time.Time) *timestamppb.Timestamp { if expiration.IsZero() { return nil } return timestamppb.New(expiration) } func factorsToPb(s *query.Session) *session.Factors { user := userFactorToPb(s.UserFactor) if user == nil { return nil } return &session.Factors{ User: user, Password: passwordFactorToPb(s.PasswordFactor), WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), Intent: intentFactorToPb(s.IntentFactor), Totp: totpFactorToPb(s.TOTPFactor), OtpSms: otpFactorToPb(s.OTPSMSFactor), OtpEmail: otpFactorToPb(s.OTPEmailFactor), } } func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor { if factor.PasswordCheckedAt.IsZero() { return nil } return &session.PasswordFactor{ VerifiedAt: timestamppb.New(factor.PasswordCheckedAt), } } func intentFactorToPb(factor query.SessionIntentFactor) *session.IntentFactor { if factor.IntentCheckedAt.IsZero() { return nil } return &session.IntentFactor{ VerifiedAt: timestamppb.New(factor.IntentCheckedAt), } } func webAuthNFactorToPb(factor query.SessionWebAuthNFactor) *session.WebAuthNFactor { if factor.WebAuthNCheckedAt.IsZero() { return nil } return &session.WebAuthNFactor{ VerifiedAt: timestamppb.New(factor.WebAuthNCheckedAt), UserVerified: factor.UserVerified, } } func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { if factor.TOTPCheckedAt.IsZero() { return nil } return &session.TOTPFactor{ VerifiedAt: timestamppb.New(factor.TOTPCheckedAt), } } func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { if factor.OTPCheckedAt.IsZero() { return nil } return &session.OTPFactor{ VerifiedAt: timestamppb.New(factor.OTPCheckedAt), } } func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { if factor.UserID == "" || factor.UserCheckedAt.IsZero() { return nil } return &session.UserFactor{ VerifiedAt: timestamppb.New(factor.UserCheckedAt), Id: factor.UserID, LoginName: factor.LoginName, DisplayName: factor.DisplayName, OrganisationId: factor.ResourceOwner, OrganizationId: factor.ResourceOwner, } } func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { offset, limit, asc := object.ListQueryToQuery(req.Query) queries, err := sessionQueriesToQuery(ctx, req.GetQueries()) if err != nil { return nil, err } return &query.SessionsSearchQueries{ SearchRequest: query.SearchRequest{ Offset: offset, Limit: limit, Asc: asc, SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), }, Queries: queries, }, nil } func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { q := make([]query.SearchQuery, len(queries)+1) for i, v := range queries { q[i], err = sessionQueryToQuery(v) if err != nil { return nil, err } } creatorQuery, err := query.NewSessionCreatorSearchQuery(authz.GetCtxData(ctx).UserID) if err != nil { return nil, err } q[len(queries)] = creatorQuery return q, nil } func sessionQueryToQuery(sq *session.SearchQuery) (query.SearchQuery, error) { switch q := sq.Query.(type) { case *session.SearchQuery_IdsQuery: return idsQueryToQuery(q.IdsQuery) case *session.SearchQuery_UserIdQuery: return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) case *session.SearchQuery_CreationDateQuery: return creationDateQueryToQuery(q.CreationDateQuery) default: return nil, caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") } } func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { return query.NewSessionIDsSearchQuery(q.Ids) } func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { comparison := timestampComparisons[q.GetMethod()] return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) } func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { switch field { case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: return query.SessionColumnCreationDate default: return query.Column{} } } func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) { checks, err := s.checksToCommand(ctx, req.Checks) if err != nil { return nil, nil, nil, 0, err } return checks, req.GetMetadata(), userAgentToCommand(req.GetUserAgent()), req.GetLifetime().AsDuration(), nil } func userAgentToCommand(userAgent *session.UserAgent) *domain.UserAgent { if userAgent == nil { return nil } out := &domain.UserAgent{ FingerprintID: userAgent.FingerprintId, IP: net.ParseIP(userAgent.GetIp()), Description: userAgent.Description, } if len(userAgent.Header) > 0 { out.Header = make(http.Header, len(userAgent.Header)) for k, values := range userAgent.Header { out.Header[k] = values.GetValues() } } return out } func (s *Server) setSessionRequestToCommand(ctx context.Context, req *session.SetSessionRequest) ([]command.SessionCommand, error) { checks, err := s.checksToCommand(ctx, req.Checks) if err != nil { return nil, err } return checks, nil } func (s *Server) checksToCommand(ctx context.Context, checks *session.Checks) ([]command.SessionCommand, error) { checkUser, err := userCheck(checks.GetUser()) if err != nil { return nil, err } sessionChecks := make([]command.SessionCommand, 0, 7) if checkUser != nil { user, err := checkUser.search(ctx, s.query) if err != nil { return nil, err } sessionChecks = append(sessionChecks, command.CheckUser(user.ID, user.ResourceOwner)) } if password := checks.GetPassword(); password != nil { sessionChecks = append(sessionChecks, command.CheckPassword(password.GetPassword())) } if intent := checks.GetIdpIntent(); intent != nil { sessionChecks = append(sessionChecks, command.CheckIntent(intent.GetIdpIntentId(), intent.GetIdpIntentToken())) } if passkey := checks.GetWebAuthN(); passkey != nil { sessionChecks = append(sessionChecks, s.command.CheckWebAuthN(passkey.GetCredentialAssertionData())) } if totp := checks.GetTotp(); totp != nil { sessionChecks = append(sessionChecks, command.CheckTOTP(totp.GetCode())) } if otp := checks.GetOtpSms(); otp != nil { sessionChecks = append(sessionChecks, command.CheckOTPSMS(otp.GetCode())) } if otp := checks.GetOtpEmail(); otp != nil { sessionChecks = append(sessionChecks, command.CheckOTPEmail(otp.GetCode())) } return sessionChecks, nil } func (s *Server) challengesToCommand(challenges *session.RequestChallenges, cmds []command.SessionCommand) (*session.Challenges, []command.SessionCommand, error) { if challenges == nil { return nil, cmds, nil } resp := new(session.Challenges) if req := challenges.GetWebAuthN(); req != nil { challenge, cmd := s.createWebAuthNChallengeCommand(req) resp.WebAuthN = challenge cmds = append(cmds, cmd) } if req := challenges.GetOtpSms(); req != nil { challenge, cmd := s.createOTPSMSChallengeCommand(req) resp.OtpSms = challenge cmds = append(cmds, cmd) } if req := challenges.GetOtpEmail(); req != nil { challenge, cmd, err := s.createOTPEmailChallengeCommand(req) if err != nil { return nil, nil, err } resp.OtpEmail = challenge cmds = append(cmds, cmd) } return resp, cmds, nil } func (s *Server) createWebAuthNChallengeCommand(req *session.RequestChallenges_WebAuthN) (*session.Challenges_WebAuthN, command.SessionCommand) { challenge := &session.Challenges_WebAuthN{ PublicKeyCredentialRequestOptions: new(structpb.Struct), } userVerification := userVerificationRequirementToDomain(req.GetUserVerificationRequirement()) return challenge, s.command.CreateWebAuthNChallenge(userVerification, req.GetDomain(), challenge.PublicKeyCredentialRequestOptions) } func userVerificationRequirementToDomain(req session.UserVerificationRequirement) domain.UserVerificationRequirement { switch req { case session.UserVerificationRequirement_USER_VERIFICATION_REQUIREMENT_UNSPECIFIED: return domain.UserVerificationRequirementUnspecified case session.UserVerificationRequirement_USER_VERIFICATION_REQUIREMENT_REQUIRED: return domain.UserVerificationRequirementRequired case session.UserVerificationRequirement_USER_VERIFICATION_REQUIREMENT_PREFERRED: return domain.UserVerificationRequirementPreferred case session.UserVerificationRequirement_USER_VERIFICATION_REQUIREMENT_DISCOURAGED: return domain.UserVerificationRequirementDiscouraged default: return domain.UserVerificationRequirementUnspecified } } func (s *Server) createOTPSMSChallengeCommand(req *session.RequestChallenges_OTPSMS) (*string, command.SessionCommand) { if req.GetReturnCode() { challenge := new(string) return challenge, s.command.CreateOTPSMSChallengeReturnCode(challenge) } return nil, s.command.CreateOTPSMSChallenge() } func (s *Server) createOTPEmailChallengeCommand(req *session.RequestChallenges_OTPEmail) (*string, command.SessionCommand, error) { switch t := req.GetDeliveryType().(type) { case *session.RequestChallenges_OTPEmail_SendCode_: cmd, err := s.command.CreateOTPEmailChallengeURLTemplate(t.SendCode.GetUrlTemplate()) if err != nil { return nil, nil, err } return nil, cmd, nil case *session.RequestChallenges_OTPEmail_ReturnCode_: challenge := new(string) return challenge, s.command.CreateOTPEmailChallengeReturnCode(challenge), nil case nil: return nil, s.command.CreateOTPEmailChallenge(), nil default: return nil, nil, caos_errs.ThrowUnimplementedf(nil, "SESSION-k3ng0", "delivery_type oneOf %T in OTPEmailChallenge not implemented", t) } } func userCheck(user *session.CheckUser) (userSearch, error) { if user == nil { return nil, nil } switch s := user.GetSearch().(type) { case *session.CheckUser_UserId: return userByID(s.UserId), nil case *session.CheckUser_LoginName: return userByLoginName(s.LoginName) default: return nil, caos_errs.ThrowUnimplementedf(nil, "SESSION-d3b4g0", "user search %T not implemented", s) } } type userSearch interface { search(ctx context.Context, q *query.Queries) (*query.User, error) } func userByID(userID string) userSearch { return userSearchByID{userID} } func userByLoginName(loginName string) (userSearch, error) { return userSearchByLoginName{loginName}, nil } type userSearchByID struct { id string } func (u userSearchByID) search(ctx context.Context, q *query.Queries) (*query.User, error) { return q.GetUserByID(ctx, true, u.id) } type userSearchByLoginName struct { loginName string } func (u userSearchByLoginName) search(ctx context.Context, q *query.Queries) (*query.User, error) { return q.GetUserByLoginName(ctx, true, u.loginName) }