package middleware import ( "context" "fmt" "net/http" "net/http/httptest" "reflect" "testing" "time" "github.com/stretchr/testify/assert" "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" zitadel_http "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/feature" ) func Test_instanceInterceptor_Handler(t *testing.T) { type fields struct { verifier authz.InstanceVerifier } type args struct { request *http.Request } type res struct { statusCode int context context.Context } tests := []struct { name string fields fields args args res res }{ { "setInstance error", fields{ verifier: &mockInstanceVerifier{}, }, args{ request: httptest.NewRequest("", "/url", nil), }, res{ statusCode: 404, context: nil, }, }, { "setInstance ok", fields{ verifier: &mockInstanceVerifier{instanceHost: "host"}, }, args{ request: func() *http.Request { r := httptest.NewRequest("", "/url", nil) r = r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "host"})) return r }(), }, res{ statusCode: 200, context: authz.WithInstance(zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "host"}), &mockInstance{}), }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &instanceInterceptor{ verifier: tt.fields.verifier, translator: newZitadelTranslator(), } next := &testHandler{} got := a.HandlerFunc(next.ServeHTTP) rr := httptest.NewRecorder() got.ServeHTTP(rr, tt.args.request) assert.Equal(t, tt.res.statusCode, rr.Code) assert.Equal(t, tt.res.context, next.context) }) } } func Test_instanceInterceptor_HandlerFunc(t *testing.T) { type fields struct { verifier authz.InstanceVerifier } type args struct { request *http.Request } type res struct { statusCode int context context.Context } tests := []struct { name string fields fields args args res res }{ { "setInstance error", fields{ verifier: &mockInstanceVerifier{}, }, args{ request: httptest.NewRequest("", "/url", nil), }, res{ statusCode: 404, context: nil, }, }, { "setInstance ok", fields{ verifier: &mockInstanceVerifier{instanceHost: "host"}, }, args{ request: func() *http.Request { r := httptest.NewRequest("", "/url", nil) r = r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "host"})) return r }(), }, res{ statusCode: 200, context: authz.WithInstance(zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "host"}), &mockInstance{}), }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &instanceInterceptor{ verifier: tt.fields.verifier, translator: newZitadelTranslator(), } next := &testHandler{} got := a.HandlerFunc(next.ServeHTTP) rr := httptest.NewRecorder() got.ServeHTTP(rr, tt.args.request) assert.Equal(t, tt.res.statusCode, rr.Code) assert.Equal(t, tt.res.context, next.context) }) } } func Test_setInstance(t *testing.T) { type args struct { r *http.Request verifier authz.InstanceVerifier } type res struct { want context.Context err bool } tests := []struct { name string args args res res }{ { "no domain context, not found error", args{ r: func() *http.Request { r := httptest.NewRequest("", "/url", nil) return r }(), verifier: &mockInstanceVerifier{}, }, res{ want: nil, err: true, }, }, { "instanceHost found, ok", args{ r: func() *http.Request { r := httptest.NewRequest("", "/url", nil) return r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "host", Protocol: "https"})) }(), verifier: &mockInstanceVerifier{instanceHost: "host"}, }, res{ want: authz.WithInstance(zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "host", Protocol: "https"}), &mockInstance{}), err: false, }, }, { "instanceHost not found, error", args{ r: func() *http.Request { r := httptest.NewRequest("", "/url", nil) return r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "fromorigin:9999", Protocol: "https"})) }(), verifier: &mockInstanceVerifier{instanceHost: "unknowndomain"}, }, res{ want: nil, err: true, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := setInstance(tt.args.r, tt.args.verifier) if (err != nil) != tt.res.err { t.Errorf("setInstance() error = %v, wantErr %v", err, tt.res.err) return } if !reflect.DeepEqual(got, tt.res.want) { t.Errorf("setInstance() got = %v, want %v", got, tt.res.want) } }) } } type testHandler struct { context context.Context } func (t *testHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { t.context = r.Context() } type mockInstanceVerifier struct { instanceHost string publicHost string } func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, instanceHost, publicHost string) (authz.Instance, error) { if instanceHost != m.instanceHost { return nil, fmt.Errorf("invalid host") } if publicHost == "" { return &mockInstance{}, nil } if publicHost != instanceHost && publicHost != m.publicHost { return nil, fmt.Errorf("invalid host") } return &mockInstance{}, nil } func (m *mockInstanceVerifier) InstanceByID(context.Context, string) (authz.Instance, error) { return nil, nil } type mockInstance struct{} func (m *mockInstance) Block() *bool { panic("shouldn't be called here") } func (m *mockInstance) AuditLogRetention() *time.Duration { panic("shouldn't be called here") } func (m *mockInstance) InstanceID() string { return "instanceID" } func (m *mockInstance) ProjectID() string { return "projectID" } func (m *mockInstance) ConsoleClientID() string { return "consoleClientID" } func (m *mockInstance) ConsoleApplicationID() string { return "consoleApplicationID" } func (m *mockInstance) DefaultLanguage() language.Tag { return language.English } func (m *mockInstance) DefaultOrganisationID() string { return "orgID" } func (m *mockInstance) SecurityPolicyAllowedOrigins() []string { return nil } func (m *mockInstance) EnableImpersonation() bool { return false } func (m *mockInstance) Features() feature.Features { return feature.Features{} }