zitadel/internal/api/grpc/server/middleware/instance_interceptor_test.go
Elio Bischof f4c72cbe14
feat: improve instance not found error (#7413)
* feat: improve instance not found error

* unit tests

* check if is templatable

* lint

* assert

* compile tests

* remove error templates

* link to instance not found page

* fmt

* cleanup

* lint
2024-02-28 10:49:57 +00:00

220 lines
4.4 KiB
Go

package middleware
import (
"context"
"fmt"
"reflect"
"testing"
"time"
"golang.org/x/text/language"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/feature"
)
func Test_hostNameFromContext(t *testing.T) {
type args struct {
ctx context.Context
headerName string
}
type res struct {
want string
err bool
}
tests := []struct {
name string
args args
res res
}{
{
"empty context, error",
args{
ctx: context.Background(),
headerName: "header",
},
res{
want: "",
err: true,
},
},
{
"header not found",
args{
ctx: metadata.NewIncomingContext(context.Background(), nil),
headerName: "header",
},
res{
want: "",
err: true,
},
},
{
"header not found",
args{
ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("header", "value")),
headerName: "header",
},
res{
want: "value",
err: false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := hostFromContext(tt.args.ctx, tt.args.headerName)
if (err != nil) != tt.res.err {
t.Errorf("hostFromContext() error = %v, wantErr %v", err, tt.res.err)
return
}
if got != tt.res.want {
t.Errorf("hostFromContext() got = %v, want %v", got, tt.res.want)
}
})
}
}
func Test_setInstance(t *testing.T) {
type args struct {
ctx context.Context
req interface{}
info *grpc.UnaryServerInfo
handler grpc.UnaryHandler
verifier authz.InstanceVerifier
headerName string
}
type res struct {
want interface{}
err bool
}
tests := []struct {
name string
args args
res res
}{
{
"hostname not found, error",
args{
ctx: context.Background(),
},
res{
want: nil,
err: true,
},
},
{
"invalid host, error",
args{
ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("header", "host2")),
req: &mockRequest{},
verifier: &mockInstanceVerifier{"host"},
headerName: "header",
},
res{
want: nil,
err: true,
},
},
{
"valid host",
args{
ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("header", "host")),
req: &mockRequest{},
verifier: &mockInstanceVerifier{"host"},
headerName: "header",
handler: func(ctx context.Context, req interface{}) (interface{}, error) {
return req, nil
},
},
res{
want: &mockRequest{},
err: false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := setInstance(tt.args.ctx, tt.args.req, tt.args.info, tt.args.handler, tt.args.verifier, tt.args.headerName, "", nil)
if (err != nil) != tt.res.err {
t.Errorf("setInstance() error = %v, wantErr %v", err, tt.res.err)
return
}
if !reflect.DeepEqual(got, tt.res.want) {
t.Errorf("setInstance() got = %v, want %v", got, tt.res.want)
}
})
}
}
type mockRequest struct{}
type mockInstanceVerifier struct {
host string
}
func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (authz.Instance, error) {
if host != m.host {
return nil, fmt.Errorf("invalid host")
}
return &mockInstance{}, nil
}
func (m *mockInstanceVerifier) InstanceByID(context.Context) (authz.Instance, error) { return nil, nil }
type mockInstance struct{}
func (m *mockInstance) Block() *bool {
panic("shouldn't be called here")
}
func (m *mockInstance) AuditLogRetention() *time.Duration {
panic("shouldn't be called here")
}
func (m *mockInstance) InstanceID() string {
return "instanceID"
}
func (m *mockInstance) ProjectID() string {
return "projectID"
}
func (m *mockInstance) ConsoleClientID() string {
return "consoleClientID"
}
func (m *mockInstance) ConsoleApplicationID() string {
return "consoleApplicationID"
}
func (m *mockInstance) DefaultLanguage() language.Tag {
return language.English
}
func (m *mockInstance) DefaultOrganisationID() string {
return "orgID"
}
func (m *mockInstance) RequestedDomain() string {
return "localhost"
}
func (m *mockInstance) RequestedHost() string {
return "localhost:8080"
}
func (m *mockInstance) SecurityPolicyAllowedOrigins() []string {
return nil
}
func (m *mockInstance) EnableImpersonation() bool {
return false
}
func (m *mockInstance) Features() feature.Features {
return feature.Features{}
}