mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-12 19:14:23 +00:00
f4c72cbe14
* feat: improve instance not found error * unit tests * check if is templatable * lint * assert * compile tests * remove error templates * link to instance not found page * fmt * cleanup * lint
220 lines
4.4 KiB
Go
220 lines
4.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/text/language"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/metadata"
|
|
|
|
"github.com/zitadel/zitadel/internal/api/authz"
|
|
"github.com/zitadel/zitadel/internal/feature"
|
|
)
|
|
|
|
func Test_hostNameFromContext(t *testing.T) {
|
|
type args struct {
|
|
ctx context.Context
|
|
headerName string
|
|
}
|
|
type res struct {
|
|
want string
|
|
err bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
res res
|
|
}{
|
|
{
|
|
"empty context, error",
|
|
args{
|
|
ctx: context.Background(),
|
|
headerName: "header",
|
|
},
|
|
res{
|
|
want: "",
|
|
err: true,
|
|
},
|
|
},
|
|
{
|
|
"header not found",
|
|
args{
|
|
ctx: metadata.NewIncomingContext(context.Background(), nil),
|
|
headerName: "header",
|
|
},
|
|
res{
|
|
want: "",
|
|
err: true,
|
|
},
|
|
},
|
|
{
|
|
"header not found",
|
|
args{
|
|
ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("header", "value")),
|
|
headerName: "header",
|
|
},
|
|
res{
|
|
want: "value",
|
|
err: false,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := hostFromContext(tt.args.ctx, tt.args.headerName)
|
|
if (err != nil) != tt.res.err {
|
|
t.Errorf("hostFromContext() error = %v, wantErr %v", err, tt.res.err)
|
|
return
|
|
}
|
|
if got != tt.res.want {
|
|
t.Errorf("hostFromContext() got = %v, want %v", got, tt.res.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_setInstance(t *testing.T) {
|
|
type args struct {
|
|
ctx context.Context
|
|
req interface{}
|
|
info *grpc.UnaryServerInfo
|
|
handler grpc.UnaryHandler
|
|
verifier authz.InstanceVerifier
|
|
headerName string
|
|
}
|
|
type res struct {
|
|
want interface{}
|
|
err bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
res res
|
|
}{
|
|
{
|
|
"hostname not found, error",
|
|
args{
|
|
ctx: context.Background(),
|
|
},
|
|
res{
|
|
want: nil,
|
|
err: true,
|
|
},
|
|
},
|
|
{
|
|
"invalid host, error",
|
|
args{
|
|
ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("header", "host2")),
|
|
req: &mockRequest{},
|
|
verifier: &mockInstanceVerifier{"host"},
|
|
headerName: "header",
|
|
},
|
|
res{
|
|
want: nil,
|
|
err: true,
|
|
},
|
|
},
|
|
{
|
|
"valid host",
|
|
args{
|
|
ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("header", "host")),
|
|
req: &mockRequest{},
|
|
verifier: &mockInstanceVerifier{"host"},
|
|
headerName: "header",
|
|
handler: func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
return req, nil
|
|
},
|
|
},
|
|
res{
|
|
want: &mockRequest{},
|
|
err: false,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := setInstance(tt.args.ctx, tt.args.req, tt.args.info, tt.args.handler, tt.args.verifier, tt.args.headerName, "", nil)
|
|
if (err != nil) != tt.res.err {
|
|
t.Errorf("setInstance() error = %v, wantErr %v", err, tt.res.err)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.res.want) {
|
|
t.Errorf("setInstance() got = %v, want %v", got, tt.res.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// mockRequest is an empty request payload that Test_setInstance passes to
// setInstance.
type mockRequest struct{}
|
|
|
|
type mockInstanceVerifier struct {
|
|
host string
|
|
}
|
|
|
|
func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (authz.Instance, error) {
|
|
if host != m.host {
|
|
return nil, fmt.Errorf("invalid host")
|
|
}
|
|
return &mockInstance{}, nil
|
|
}
|
|
|
|
func (m *mockInstanceVerifier) InstanceByID(context.Context) (authz.Instance, error) { return nil, nil }
|
|
|
|
type mockInstance struct{}
|
|
|
|
func (m *mockInstance) Block() *bool {
|
|
panic("shouldn't be called here")
|
|
}
|
|
|
|
func (m *mockInstance) AuditLogRetention() *time.Duration {
|
|
panic("shouldn't be called here")
|
|
}
|
|
|
|
func (m *mockInstance) InstanceID() string {
|
|
return "instanceID"
|
|
}
|
|
|
|
func (m *mockInstance) ProjectID() string {
|
|
return "projectID"
|
|
}
|
|
|
|
func (m *mockInstance) ConsoleClientID() string {
|
|
return "consoleClientID"
|
|
}
|
|
|
|
func (m *mockInstance) ConsoleApplicationID() string {
|
|
return "consoleApplicationID"
|
|
}
|
|
|
|
func (m *mockInstance) DefaultLanguage() language.Tag {
|
|
return language.English
|
|
}
|
|
|
|
func (m *mockInstance) DefaultOrganisationID() string {
|
|
return "orgID"
|
|
}
|
|
|
|
func (m *mockInstance) RequestedDomain() string {
|
|
return "localhost"
|
|
}
|
|
|
|
func (m *mockInstance) RequestedHost() string {
|
|
return "localhost:8080"
|
|
}
|
|
|
|
func (m *mockInstance) SecurityPolicyAllowedOrigins() []string {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockInstance) EnableImpersonation() bool {
|
|
return false
|
|
}
|
|
|
|
func (m *mockInstance) Features() feature.Features {
|
|
return feature.Features{}
|
|
}
|