Merge branch 'main' into perf-introspecion

This commit is contained in:
Tim Möhlmann
2023-11-13 18:16:32 +02:00
186 changed files with 11778 additions and 1466 deletions

View File

@@ -60,7 +60,6 @@ func eventRequestToFilter(ctx context.Context, req *admin_pb.ListEventsRequest)
AwaitOpenTransactions().
ResourceOwner(req.ResourceOwner).
EditorUser(req.EditorUserId).
CreationDateAfter(req.CreationDate.AsTime()).
SequenceGreater(req.Sequence)
if len(aggregateIDs) > 0 || len(aggregateTypes) > 0 || len(eventTypes) > 0 {
@@ -71,8 +70,11 @@ func eventRequestToFilter(ctx context.Context, req *admin_pb.ListEventsRequest)
Builder()
}
if req.Asc {
if req.GetAsc() {
builder.OrderAsc()
builder.CreationDateAfter(req.CreationDate.AsTime())
} else {
builder.CreationDateBefore(req.CreationDate.AsTime())
}
return builder, nil

View File

@@ -1,7 +1,7 @@
package event
import (
structpb "github.com/golang/protobuf/ptypes/struct"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/errors"

View File

@@ -24,7 +24,7 @@ import (
func ListUsersRequestToModel(req *mgmt_pb.ListUsersRequest) (*query.UserSearchQueries, error) {
offset, limit, asc := object.ListQueryToModel(req.Query)
queries, err := user_grpc.UserQueriesToQuery(req.Queries)
queries, err := user_grpc.UserQueriesToQuery(req.Queries, 0 /*start from level 0*/)
if err != nil {
return nil, err
}

View File

@@ -2,6 +2,7 @@ package server
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strings"
@@ -9,6 +10,7 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/zitadel/logging"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/protobuf/encoding/protojson"
@@ -89,10 +91,11 @@ func CreateGatewayWithPrefix(
http1HostName string,
accessInterceptor *http_mw.AccessInterceptor,
queries *query.Queries,
tlsConfig *tls.Config,
) (http.Handler, string, error) {
runtimeMux := runtime.NewServeMux(serveMuxOptions...)
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
grpc.WithUnaryInterceptor(client_middleware.DefaultTracingClient()),
}
connection, err := dial(ctx, port, opts)
@@ -106,11 +109,17 @@ func CreateGatewayWithPrefix(
return addInterceptors(runtimeMux, http1HostName, accessInterceptor, queries), g.GatewayPathPrefix(), nil
}
func CreateGateway(ctx context.Context, port uint16, http1HostName string, accessInterceptor *http_mw.AccessInterceptor) (*Gateway, error) {
func CreateGateway(
ctx context.Context,
port uint16,
http1HostName string,
accessInterceptor *http_mw.AccessInterceptor,
tlsConfig *tls.Config,
) (*Gateway, error) {
connection, err := dial(ctx,
port,
[]grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
grpc.WithUnaryInterceptor(client_middleware.DefaultTracingClient()),
})
if err != nil {
@@ -217,3 +226,15 @@ func (r *cookieResponseWriter) WriteHeader(status int) {
}
r.ResponseWriter.WriteHeader(status)
}
func grpcCredentials(tlsConfig *tls.Config) credentials.TransportCredentials {
creds := insecure.NewCredentials()
if tlsConfig != nil {
tlsConfigClone := tlsConfig.Clone()
// We don't want to verify the certificate of the internal grpc server
// That's up to the client who called the gRPC gateway
tlsConfigClone.InsecureSkipVerify = true
creds = credentials.NewTLS(tlsConfigClone)
}
return creds
}

View File

@@ -3,12 +3,11 @@ package server
import (
"context"
"github.com/golang/protobuf/ptypes/empty"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/zitadel/logging"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/proto"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@@ -22,24 +21,23 @@ func NewValidator(validations map[string]ValidationFunction) *Validator {
return &Validator{validations: validations}
}
func (v *Validator) Healthz(_ context.Context, e *empty.Empty) (*empty.Empty, error) {
func (v *Validator) Healthz(_ context.Context, e *emptypb.Empty) (*emptypb.Empty, error) {
return e, nil
}
func (v *Validator) Ready(ctx context.Context, e *empty.Empty) (*empty.Empty, error) {
func (v *Validator) Ready(ctx context.Context, e *emptypb.Empty) (*emptypb.Empty, error) {
if len(validate(ctx, v.validations)) == 0 {
return e, nil
}
return nil, errors.ThrowInternal(nil, "API-2jD9a", "not ready")
}
func (v *Validator) Validate(ctx context.Context, _ *empty.Empty) (*structpb.Struct, error) {
validations := validate(ctx, v.validations)
return proto.ToPBStruct(validations)
func (v *Validator) Validate(ctx context.Context, _ *emptypb.Empty) (*structpb.Struct, error) {
return structpb.NewStruct(validate(ctx, v.validations))
}
func validate(ctx context.Context, validations map[string]ValidationFunction) map[string]error {
errors := make(map[string]error)
func validate(ctx context.Context, validations map[string]ValidationFunction) map[string]any {
errors := make(map[string]any)
for id, validation := range validations {
if err := validation(ctx); err != nil {
logging.Log("API-vf823").WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Error("validation failed")

View File

@@ -5,7 +5,7 @@ import (
"reflect"
"testing"
"github.com/golang/protobuf/ptypes/empty"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/zitadel/zitadel/internal/errors"
)
@@ -15,7 +15,7 @@ func TestValidator_Healthz(t *testing.T) {
validations map[string]ValidationFunction
}
type res struct {
want *empty.Empty
want *emptypb.Empty
hasErr bool
}
tests := []struct {
@@ -27,7 +27,7 @@ func TestValidator_Healthz(t *testing.T) {
"ok",
fields{},
res{
&empty.Empty{},
&emptypb.Empty{},
false,
},
},
@@ -37,7 +37,7 @@ func TestValidator_Healthz(t *testing.T) {
v := &Validator{
validations: tt.fields.validations,
}
got, err := v.Healthz(nil, &empty.Empty{})
got, err := v.Healthz(context.Background(), &emptypb.Empty{})
if (err != nil) != tt.res.hasErr {
t.Errorf("Healthz() error = %v, wantErr %v", err, tt.res.hasErr)
return
@@ -54,7 +54,7 @@ func TestValidator_Ready(t *testing.T) {
validations map[string]ValidationFunction
}
type res struct {
want *empty.Empty
want *emptypb.Empty
hasErr bool
}
tests := []struct {
@@ -82,7 +82,7 @@ func TestValidator_Ready(t *testing.T) {
},
}},
res{
&empty.Empty{},
&emptypb.Empty{},
false,
},
},
@@ -92,7 +92,7 @@ func TestValidator_Ready(t *testing.T) {
v := &Validator{
validations: tt.fields.validations,
}
got, err := v.Ready(context.Background(), &empty.Empty{})
got, err := v.Ready(context.Background(), &emptypb.Empty{})
if (err != nil) != tt.res.hasErr {
t.Errorf("Ready() error = %v, wantErr %v", err, tt.res.hasErr)
return
@@ -109,7 +109,7 @@ func Test_validate(t *testing.T) {
validations map[string]ValidationFunction
}
type res struct {
want map[string]error
want map[string]any
}
tests := []struct {
name string
@@ -126,7 +126,7 @@ func Test_validate(t *testing.T) {
},
},
res{
map[string]error{},
map[string]any{},
},
},
{
@@ -142,7 +142,7 @@ func Test_validate(t *testing.T) {
},
},
res{
map[string]error{
map[string]any{
"error": errors.ThrowInternal(nil, "id", "message"),
},
},

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"time"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -17,9 +18,20 @@ import (
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
objpb "github.com/zitadel/zitadel/pkg/grpc/object"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
)
var (
timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals,
}
)
func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) {
res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken())
if err != nil {
@@ -46,7 +58,7 @@ func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequ
}
func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) {
checks, metadata, userAgent, err := s.createSessionRequestToCommand(ctx, req)
checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req)
if err != nil {
return nil, err
}
@@ -55,7 +67,7 @@ func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRe
return nil, err
}
set, err := s.command.CreateSession(ctx, cmds, metadata, userAgent)
set, err := s.command.CreateSession(ctx, cmds, metadata, userAgent, lifetime)
if err != nil {
return nil, err
}
@@ -78,7 +90,7 @@ func (s *Server) SetSession(ctx context.Context, req *session.SetSessionRequest)
return nil, err
}
set, err := s.command.UpdateSession(ctx, req.GetSessionId(), req.GetSessionToken(), cmds, req.GetMetadata())
set, err := s.command.UpdateSession(ctx, req.GetSessionId(), req.GetSessionToken(), cmds, req.GetMetadata(), req.GetLifetime().AsDuration())
if err != nil {
return nil, err
}
@@ -113,13 +125,14 @@ func sessionsToPb(sessions []*query.Session) []*session.Session {
func sessionToPb(s *query.Session) *session.Session {
return &session.Session{
Id: s.ID,
CreationDate: timestamppb.New(s.CreationDate),
ChangeDate: timestamppb.New(s.ChangeDate),
Sequence: s.Sequence,
Factors: factorsToPb(s),
Metadata: s.Metadata,
UserAgent: userAgentToPb(s.UserAgent),
Id: s.ID,
CreationDate: timestamppb.New(s.CreationDate),
ChangeDate: timestamppb.New(s.ChangeDate),
Sequence: s.Sequence,
Factors: factorsToPb(s),
Metadata: s.Metadata,
UserAgent: userAgentToPb(s.UserAgent),
ExpirationDate: expirationToPb(s.Expiration),
}
}
@@ -147,6 +160,13 @@ func userAgentToPb(ua domain.UserAgent) *session.UserAgent {
return out
}
func expirationToPb(expiration time.Time) *timestamppb.Timestamp {
if expiration.IsZero() {
return nil
}
return timestamppb.New(expiration)
}
func factorsToPb(s *query.Session) *session.Factors {
user := userFactorToPb(s.UserFactor)
if user == nil {
@@ -231,9 +251,10 @@ func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRe
}
return &query.SessionsSearchQueries{
SearchRequest: query.SearchRequest{
Offset: offset,
Limit: limit,
Asc: asc,
Offset: offset,
Limit: limit,
Asc: asc,
SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()),
},
Queries: queries,
}, nil
@@ -241,8 +262,8 @@ func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRe
func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) {
q := make([]query.SearchQuery, len(queries)+1)
for i, query := range queries {
q[i], err = sessionQueryToQuery(query)
for i, v := range queries {
q[i], err = sessionQueryToQuery(v)
if err != nil {
return nil, err
}
@@ -255,10 +276,14 @@ func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery)
return q, nil
}
func sessionQueryToQuery(query *session.SearchQuery) (query.SearchQuery, error) {
switch q := query.Query.(type) {
func sessionQueryToQuery(sq *session.SearchQuery) (query.SearchQuery, error) {
switch q := sq.Query.(type) {
case *session.SearchQuery_IdsQuery:
return idsQueryToQuery(q.IdsQuery)
case *session.SearchQuery_UserIdQuery:
return query.NewUserIDSearchQuery(q.UserIdQuery.GetId())
case *session.SearchQuery_CreationDateQuery:
return creationDateQueryToQuery(q.CreationDateQuery)
default:
return nil, caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid")
}
@@ -268,12 +293,26 @@ func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) {
return query.NewSessionIDsSearchQuery(q.Ids)
}
func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, error) {
func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) {
comparison := timestampComparisons[q.GetMethod()]
return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison)
}
func fieldNameToSessionColumn(field session.SessionFieldName) query.Column {
switch field {
case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE:
return query.SessionColumnCreationDate
default:
return query.Column{}
}
}
func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) {
checks, err := s.checksToCommand(ctx, req.Checks)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, 0, err
}
return checks, req.GetMetadata(), userAgentToCommand(req.GetUserAgent()), nil
return checks, req.GetMetadata(), userAgentToCommand(req.GetUserAgent()), req.GetLifetime().AsDuration(), nil
}
func userAgentToCommand(userAgent *session.UserAgent) *domain.UserAgent {

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/zitadel/zitadel/internal/integration"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta"
@@ -54,7 +55,7 @@ func TestMain(m *testing.M) {
}())
}
func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, factors ...wantFactor) *session.Session {
func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, factors ...wantFactor) *session.Session {
t.Helper()
require.NotEmpty(t, id)
require.NotEmpty(t, token)
@@ -75,6 +76,11 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo
if !proto.Equal(userAgent, s.GetUserAgent()) {
t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), userAgent)
}
if expirationWindow == 0 {
assert.Nil(t, s.GetExpirationDate())
} else {
assert.WithinRange(t, s.GetExpirationDate().AsTime(), time.Now().Add(-expirationWindow), time.Now().Add(expirationWindow))
}
verifyFactors(t, s.GetFactors(), window, factors)
return s
@@ -137,12 +143,13 @@ func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration,
func TestServer_CreateSession(t *testing.T) {
tests := []struct {
name string
req *session.CreateSessionRequest
want *session.CreateSessionResponse
wantErr bool
wantFactors []wantFactor
wantUserAgent *session.UserAgent
name string
req *session.CreateSessionRequest
want *session.CreateSessionResponse
wantErr bool
wantFactors []wantFactor
wantUserAgent *session.UserAgent
wantExpirationWindow time.Duration
}{
{
name: "empty session",
@@ -182,6 +189,27 @@ func TestServer_CreateSession(t *testing.T) {
},
},
},
{
name: "negative lifetime",
req: &session.CreateSessionRequest{
Metadata: map[string][]byte{"foo": []byte("bar")},
Lifetime: durationpb.New(-5 * time.Minute),
},
wantErr: true,
},
{
name: "lifetime",
req: &session.CreateSessionRequest{
Metadata: map[string][]byte{"foo": []byte("bar")},
Lifetime: durationpb.New(5 * time.Minute),
},
want: &session.CreateSessionResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ID,
},
},
wantExpirationWindow: 5 * time.Minute,
},
{
name: "with user",
req: &session.CreateSessionRequest{
@@ -253,7 +281,7 @@ func TestServer_CreateSession(t *testing.T) {
require.NoError(t, err)
integration.AssertDetails(t, tt.want, got)
verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantFactors...)
verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantExpirationWindow, tt.wantFactors...)
})
}
}
@@ -276,7 +304,7 @@ func TestServer_CreateSession_webauthn(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
assertionData, err := Tester.WebAuthN.CreateAssertionResponse(createResp.GetChallenges().GetWebAuthN().GetPublicKeyCredentialRequestOptions(), true)
require.NoError(t, err)
@@ -292,7 +320,7 @@ func TestServer_CreateSession_webauthn(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactorUserVerified)
verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantWebAuthNFactorUserVerified)
}
func TestServer_CreateSession_successfulIntent(t *testing.T) {
@@ -308,7 +336,7 @@ func TestServer_CreateSession_successfulIntent(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
intentID, token, _, _ := Tester.CreateSuccessfulOAuthIntent(t, idpID, User.GetUserId(), "id")
updateResp, err := Client.SetSession(CTX, &session.SetSessionRequest{
@@ -322,7 +350,7 @@ func TestServer_CreateSession_successfulIntent(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantIntentFactor)
verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantIntentFactor)
}
func TestServer_CreateSession_successfulIntentUnknownUserID(t *testing.T) {
@@ -338,7 +366,7 @@ func TestServer_CreateSession_successfulIntentUnknownUserID(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
idpUserID := "id"
intentID, token, _, _ := Tester.CreateSuccessfulOAuthIntent(t, idpID, "", idpUserID)
@@ -365,7 +393,7 @@ func TestServer_CreateSession_successfulIntentUnknownUserID(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantIntentFactor)
verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantIntentFactor)
}
func TestServer_CreateSession_startedIntentFalseToken(t *testing.T) {
@@ -381,7 +409,7 @@ func TestServer_CreateSession_startedIntentFalseToken(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
intentID := Tester.CreateIntent(t, idpID)
_, err = Client.SetSession(CTX, &session.SetSessionRequest{
@@ -433,7 +461,7 @@ func TestServer_SetSession_flow(t *testing.T) {
createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{})
require.NoError(t, err)
sessionToken := createResp.GetSessionToken()
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, createResp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, createResp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
t.Run("check user", func(t *testing.T) {
resp, err := Client.SetSession(CTX, &session.SetSessionRequest{
@@ -449,7 +477,7 @@ func TestServer_SetSession_flow(t *testing.T) {
})
require.NoError(t, err)
sessionToken = resp.GetSessionToken()
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor)
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor)
})
t.Run("check webauthn, user verified (passkey)", func(t *testing.T) {
@@ -464,7 +492,7 @@ func TestServer_SetSession_flow(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
sessionToken = resp.GetSessionToken()
assertionData, err := Tester.WebAuthN.CreateAssertionResponse(resp.GetChallenges().GetWebAuthN().GetPublicKeyCredentialRequestOptions(), true)
@@ -481,7 +509,7 @@ func TestServer_SetSession_flow(t *testing.T) {
})
require.NoError(t, err)
sessionToken = resp.GetSessionToken()
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactorUserVerified)
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantWebAuthNFactorUserVerified)
})
userAuthCtx := Tester.WithAuthorizationToken(CTX, sessionToken)
@@ -508,7 +536,7 @@ func TestServer_SetSession_flow(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
sessionToken = resp.GetSessionToken()
assertionData, err := Tester.WebAuthN.CreateAssertionResponse(resp.GetChallenges().GetWebAuthN().GetPublicKeyCredentialRequestOptions(), false)
@@ -525,7 +553,7 @@ func TestServer_SetSession_flow(t *testing.T) {
})
require.NoError(t, err)
sessionToken = resp.GetSessionToken()
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor)
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantWebAuthNFactor)
})
}
})
@@ -544,7 +572,7 @@ func TestServer_SetSession_flow(t *testing.T) {
})
require.NoError(t, err)
sessionToken = resp.GetSessionToken()
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor, wantTOTPFactor)
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantWebAuthNFactor, wantTOTPFactor)
})
t.Run("check OTP SMS", func(t *testing.T) {
@@ -556,7 +584,7 @@ func TestServer_SetSession_flow(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
sessionToken = resp.GetSessionToken()
otp := resp.GetChallenges().GetOtpSms()
@@ -573,7 +601,7 @@ func TestServer_SetSession_flow(t *testing.T) {
})
require.NoError(t, err)
sessionToken = resp.GetSessionToken()
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor, wantOTPSMSFactor)
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantWebAuthNFactor, wantOTPSMSFactor)
})
t.Run("check OTP Email", func(t *testing.T) {
@@ -587,7 +615,7 @@ func TestServer_SetSession_flow(t *testing.T) {
},
})
require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil)
verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0)
sessionToken = resp.GetSessionToken()
otp := resp.GetChallenges().GetOtpEmail()
@@ -604,10 +632,34 @@ func TestServer_SetSession_flow(t *testing.T) {
})
require.NoError(t, err)
sessionToken = resp.GetSessionToken()
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor, wantOTPEmailFactor)
verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, 0, wantUserFactor, wantWebAuthNFactor, wantOTPEmailFactor)
})
}
func TestServer_SetSession_expired(t *testing.T) {
createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{
Lifetime: durationpb.New(20 * time.Second),
})
require.NoError(t, err)
// test session token works
sessionResp, err := Tester.Client.SessionV2.SetSession(CTX, &session.SetSessionRequest{
SessionId: createResp.GetSessionId(),
SessionToken: createResp.GetSessionToken(),
Lifetime: durationpb.New(20 * time.Second),
})
require.NoError(t, err)
// ensure session expires and does not work anymore
time.Sleep(20 * time.Second)
_, err = Tester.Client.SessionV2.SetSession(CTX, &session.SetSessionRequest{
SessionId: createResp.GetSessionId(),
SessionToken: sessionResp.GetSessionToken(),
Lifetime: durationpb.New(20 * time.Second),
})
require.Error(t, err)
}
func Test_ZITADEL_API_missing_authentication(t *testing.T) {
// create new, empty session
createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{})
@@ -629,7 +681,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) {
}
func Test_ZITADEL_API_success(t *testing.T) {
id, token, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, User.GetUserId())
id, token, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId())
ctx := Tester.WithAuthorizationToken(context.Background(), token)
sessionResp, err := Tester.Client.SessionV2.GetSession(ctx, &session.GetSessionRequest{SessionId: id})
@@ -641,7 +693,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
}
func Test_ZITADEL_API_session_not_found(t *testing.T) {
id, token, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, User.GetUserId())
id, token, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId())
// test session token works
ctx := Tester.WithAuthorizationToken(context.Background(), token)
@@ -658,3 +710,18 @@ func Test_ZITADEL_API_session_not_found(t *testing.T) {
_, err = Tester.Client.SessionV2.GetSession(ctx, &session.GetSessionRequest{SessionId: id})
require.Error(t, err)
}
func Test_ZITADEL_API_session_expired(t *testing.T) {
id, token, _, _ := Tester.CreateVerifiedWebAuthNSessionWithLifetime(t, CTX, User.GetUserId(), 20*time.Second)
// test session token works
ctx := Tester.WithAuthorizationToken(context.Background(), token)
_, err := Tester.Client.SessionV2.GetSession(ctx, &session.GetSessionRequest{SessionId: id})
require.NoError(t, err)
// ensure session expires and does not work anymore
time.Sleep(20 * time.Second)
sessionResp, err := Tester.Client.SessionV2.GetSession(ctx, &session.GetSessionRequest{SessionId: id})
require.Error(t, err)
require.Nil(t, sessionResp)
}

View File

@@ -14,6 +14,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
objpb "github.com/zitadel/zitadel/pkg/grpc/object"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
@@ -22,12 +23,16 @@ import (
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
)
var (
creationDate = time.Date(2023, 10, 10, 14, 15, 0, 0, time.UTC)
)
func Test_sessionsToPb(t *testing.T) {
now := time.Now()
past := now.Add(-time.Hour)
sessions := []*query.Session{
{ // no factor, with user agent
{ // no factor, with user agent and expiration
ID: "999",
CreationDate: now,
ChangeDate: now,
@@ -42,6 +47,7 @@ func Test_sessionsToPb(t *testing.T) {
IP: net.IPv4(1, 2, 3, 4),
Header: http.Header{"foo": []string{"foo", "bar"}},
},
Expiration: now,
},
{ // user factor
ID: "999",
@@ -124,7 +130,7 @@ func Test_sessionsToPb(t *testing.T) {
}
want := []*session.Session{
{ // no factor, with user agent
{ // no factor, with user agent and expiration
Id: "999",
CreationDate: timestamppb.New(now),
ChangeDate: timestamppb.New(now),
@@ -139,6 +145,7 @@ func Test_sessionsToPb(t *testing.T) {
"foo": {Values: []string{"foo", "bar"}},
},
},
ExpirationDate: timestamppb.New(now),
},
{ // user factor
Id: "999",
@@ -307,11 +314,18 @@ func mustNewListQuery(t testing.TB, column query.Column, list []any, compare que
return q
}
func mustNewTimestampQuery(t testing.TB, column query.Column, ts time.Time, compare query.TimestampComparison) query.SearchQuery {
q, err := query.NewTimestampQuery(column, ts, compare)
require.NoError(t, err)
return q
}
func Test_listSessionsRequestToQuery(t *testing.T) {
type args struct {
ctx context.Context
req *session.ListSessionsRequest
}
tests := []struct {
name string
args args
@@ -335,6 +349,26 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
},
},
},
{
name: "default request with sorting column",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
req: &session.ListSessionsRequest{
SortingColumn: session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE,
},
},
want: &query.SessionsSearchQueries{
SearchRequest: query.SearchRequest{
Offset: 0,
Limit: 0,
SortingColumn: query.SessionColumnCreationDate,
Asc: false,
},
Queries: []query.SearchQuery{
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
},
{
name: "with list query and sessions",
args: args{
@@ -356,6 +390,17 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
Ids: []string{"4", "5", "6"},
},
}},
{Query: &session.SearchQuery_UserIdQuery{
UserIdQuery: &session.UserIDQuery{
Id: "10",
},
}},
{Query: &session.SearchQuery_CreationDateQuery{
CreationDateQuery: &session.CreationDateQuery{
CreationDate: timestamppb.New(creationDate),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER,
},
}},
},
},
},
@@ -368,6 +413,8 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
Queries: []query.SearchQuery{
mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn),
mustNewListQuery(t, query.SessionColumnID, []interface{}{"4", "5", "6"}, query.ListIn),
mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals),
mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater),
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
@@ -485,7 +532,7 @@ func Test_sessionQueryToQuery(t *testing.T) {
wantErr: caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"),
},
{
name: "query",
name: "ids query",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
@@ -495,6 +542,40 @@ func Test_sessionQueryToQuery(t *testing.T) {
}},
want: mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn),
},
{
name: "user id query",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_UserIdQuery{
UserIdQuery: &session.UserIDQuery{
Id: "10",
},
},
}},
want: mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals),
},
{
name: "creation date query",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_CreationDateQuery{
CreationDateQuery: &session.CreationDateQuery{
CreationDate: timestamppb.New(creationDate),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS,
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampLess),
},
{
name: "creation date query with default method",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_CreationDateQuery{
CreationDateQuery: &session.CreationDateQuery{
CreationDate: timestamppb.New(creationDate),
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampEquals),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -7,10 +7,10 @@ import (
user_pb "github.com/zitadel/zitadel/pkg/grpc/user"
)
func UserQueriesToQuery(queries []*user_pb.SearchQuery) (_ []query.SearchQuery, err error) {
func UserQueriesToQuery(queries []*user_pb.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
q := make([]query.SearchQuery, len(queries))
for i, query := range queries {
q[i], err = UserQueryToQuery(query)
q[i], err = UserQueryToQuery(query, level)
if err != nil {
return nil, err
}
@@ -18,7 +18,11 @@ func UserQueriesToQuery(queries []*user_pb.SearchQuery) (_ []query.SearchQuery,
return q, nil
}
func UserQueryToQuery(query *user_pb.SearchQuery) (query.SearchQuery, error) {
func UserQueryToQuery(query *user_pb.SearchQuery, level uint8) (query.SearchQuery, error) {
if level > 20 {
// can't go deeper than 20 levels of nesting.
return nil, errors.ThrowInvalidArgument(nil, "USER-zsQ97", "Errors.User.TooManyNestingLevels")
}
switch q := query.Query.(type) {
case *user_pb.SearchQuery_UserNameQuery:
return UserNameQueryToQuery(q.UserNameQuery)
@@ -42,6 +46,12 @@ func UserQueryToQuery(query *user_pb.SearchQuery) (query.SearchQuery, error) {
return ResourceOwnerQueryToQuery(q.ResourceOwner)
case *user_pb.SearchQuery_InUserIdsQuery:
return InUserIdsQueryToQuery(q.InUserIdsQuery)
case *user_pb.SearchQuery_OrQuery:
return OrQueryToQuery(q.OrQuery, level)
case *user_pb.SearchQuery_AndQuery:
return AndQueryToQuery(q.AndQuery, level)
case *user_pb.SearchQuery_NotQuery:
return NotQueryToQuery(q.NotQuery, level)
default:
return nil, errors.ThrowInvalidArgument(nil, "GRPC-vR9nC", "List.Query.Invalid")
}
@@ -90,3 +100,24 @@ func ResourceOwnerQueryToQuery(q *user_pb.ResourceOwnerQuery) (query.SearchQuery
func InUserIdsQueryToQuery(q *user_pb.InUserIDQuery) (query.SearchQuery, error) {
return query.NewUserInUserIdsSearchQuery(q.UserIds)
}
func OrQueryToQuery(q *user_pb.OrQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, err := UserQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}
return query.NewUserOrSearchQuery(mappedQueries)
}
func AndQueryToQuery(q *user_pb.AndQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, err := UserQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}
return query.NewUserAndSearchQuery(mappedQueries)
}
func NotQueryToQuery(q *user_pb.NotQuery, level uint8) (query.SearchQuery, error) {
mappedQuery, err := UserQueryToQuery(q.Query, level+1)
if err != nil {
return nil, err
}
return query.NewUserNotSearchQuery(mappedQuery)
}

View File

@@ -16,13 +16,13 @@ import (
func TestServer_AddOTPSMS(t *testing.T) {
userID := Tester.CreateHumanUser(CTX).GetUserId()
Tester.RegisterUserPasskey(CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userID)
// TODO: add when phone can be added to user
/*
userIDPhone := Tester.CreateHumanUser(CTX).GetUserId()
Tester.RegisterUserPasskey(CTX, userIDPhone)
_, sessionTokenPhone, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, userIDPhone)
_, sessionTokenPhone, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userIDPhone)
*/
type args struct {
ctx context.Context
@@ -99,7 +99,7 @@ func TestServer_RemoveOTPSMS(t *testing.T) {
/*
userID := Tester.CreateHumanUser(CTX).GetUserId()
Tester.RegisterUserPasskey(CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userID)
*/
type args struct {
@@ -157,7 +157,7 @@ func TestServer_RemoveOTPSMS(t *testing.T) {
func TestServer_AddOTPEmail(t *testing.T) {
userID := Tester.CreateHumanUser(CTX).GetUserId()
Tester.RegisterUserPasskey(CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userID)
userVerified := Tester.CreateHumanUser(CTX)
_, err := Tester.Client.UserV2.VerifyEmail(CTX, &user.VerifyEmailRequest{
@@ -166,7 +166,7 @@ func TestServer_AddOTPEmail(t *testing.T) {
})
require.NoError(t, err)
Tester.RegisterUserPasskey(CTX, userVerified.GetUserId())
_, sessionTokenVerified, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, userVerified.GetUserId())
_, sessionTokenVerified, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userVerified.GetUserId())
type args struct {
ctx context.Context
@@ -238,11 +238,11 @@ func TestServer_AddOTPEmail(t *testing.T) {
func TestServer_RemoveOTPEmail(t *testing.T) {
userID := Tester.CreateHumanUser(CTX).GetUserId()
Tester.RegisterUserPasskey(CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userID)
userVerified := Tester.CreateHumanUser(CTX)
Tester.RegisterUserPasskey(CTX, userVerified.GetUserId())
_, sessionTokenVerified, _, _ := Tester.CreateVerfiedWebAuthNSession(t, CTX, userVerified.GetUserId())
_, sessionTokenVerified, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userVerified.GetUserId())
userVerifiedCtx := Tester.WithAuthorizationToken(context.Background(), sessionTokenVerified)
_, err := Tester.Client.UserV2.VerifyEmail(userVerifiedCtx, &user.VerifyEmailRequest{
UserId: userVerified.GetUserId(),