diff --git a/cmd/start/config.go b/cmd/start/config.go index 910759b653..5e28e1dfb2 100644 --- a/cmd/start/config.go +++ b/cmd/start/config.go @@ -125,7 +125,9 @@ func MustNewConfig(v *viper.Viper) *Config { logging.OnError(err).Fatal("unable to set profiler") id.Configure(config.Machine) - actions.SetHTTPConfig(&config.Actions.HTTP) + if config.Actions != nil { + actions.SetHTTPConfig(&config.Actions.HTTP) + } // Copy the global role permissions mappings to the instance until we allow instance-level configuration over the API. config.DefaultInstance.RolePermissionMappings = config.InternalAuthZ.RolePermissionMappings diff --git a/internal/actions/http_module.go b/internal/actions/http_module.go index 2f9d09932c..db7253428d 100644 --- a/internal/actions/http_module.go +++ b/internal/actions/http_module.go @@ -176,16 +176,16 @@ type transport struct { } func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { - if httpConfig == nil { + if httpConfig == nil || len(httpConfig.DenyList) == 0 { return http.DefaultTransport.RoundTrip(req) } - if t.isHostBlocked(httpConfig.DenyList, req.URL) { - return nil, zerrors.ThrowInvalidArgument(nil, "ACTIO-N72d0", "host is denied") + if err := t.isHostBlocked(httpConfig.DenyList, req.URL); err != nil { + return nil, zerrors.ThrowInvalidArgument(err, "ACTIO-N72d0", "host is denied") } return http.DefaultTransport.RoundTrip(req) } -func (t *transport) isHostBlocked(denyList []AddressChecker, address *url.URL) bool { +func (t *transport) isHostBlocked(denyList []AddressChecker, address *url.URL) error { host := address.Hostname() ip := net.ParseIP(host) ips := []net.IP{ip} @@ -194,17 +194,17 @@ func (t *transport) isHostBlocked(denyList []AddressChecker, address *url.URL) b var err error ips, err = t.lookup(host) if err != nil { - return true + return zerrors.ThrowInternal(err, "ACTIO-4m9s2", "lookup failed") } } - for _, blocked := range denyList { - if blocked.Matches(ips, host) { - return true + for _, denied := range denyList { + if err := denied.IsDenied(ips, host); err != nil { + return err } } - return false + return nil } type AddressChecker interface { - Matches([]net.IP, string) bool + IsDenied([]net.IP, string) error } diff --git a/internal/actions/http_module_config.go b/internal/actions/http_module_config.go index d1b965814e..eaab9e754e 100644 --- a/internal/actions/http_module_config.go +++ b/internal/actions/http_module_config.go @@ -1,6 +1,8 @@ package actions import ( + "errors" + "fmt" "net" "reflect" "strings" @@ -60,6 +62,9 @@ func HTTPConfigDecodeHook(from, to reflect.Value) (interface{}, error) { } func NewHostChecker(entry string) (AddressChecker, error) { + if entry == "" { + return nil, nil + } _, network, err := net.ParseCIDR(entry) if err == nil { return &HostChecker{Net: network}, nil @@ -76,19 +81,39 @@ type HostChecker struct { Domain string } -func (c *HostChecker) Matches(ips []net.IP, address string) bool { +type AddressDeniedError struct { + deniedBy string +} + +func NewAddressDeniedError(deniedBy string) *AddressDeniedError { + return &AddressDeniedError{deniedBy: deniedBy} +} + +func (e *AddressDeniedError) Error() string { + return fmt.Sprintf("address is denied by '%s'", e.deniedBy) +} + +func (e *AddressDeniedError) Is(target error) bool { + var addressDeniedErr *AddressDeniedError + if !errors.As(target, &addressDeniedErr) { + return false + } + return e.deniedBy == addressDeniedErr.deniedBy +} + +func (c *HostChecker) IsDenied(ips []net.IP, address string) error { // if the address matches the domain, no additional checks as needed if c.Domain == address { - return true + return NewAddressDeniedError(c.Domain) } // otherwise we need to check on ips (incl. the resolved ips of the host) for _, ip := range ips { if c.Net != nil && c.Net.Contains(ip) { - return true + return NewAddressDeniedError(c.Net.String()) } if c.IP != nil && c.IP.Equal(ip) { - return true + return NewAddressDeniedError(c.IP.String()) } } - return false + return nil } diff --git a/internal/actions/http_module_test.go b/internal/actions/http_module_test.go index 7a1f8d7816..50a007feeb 100644 --- a/internal/actions/http_module_test.go +++ b/internal/actions/http_module_test.go @@ -3,6 +3,7 @@ package actions import ( "bytes" "context" + "errors" "io" "net" "net/http" @@ -11,6 +12,7 @@ import ( "testing" "github.com/dop251/goja" + "github.com/stretchr/testify/assert" "github.com/zitadel/zitadel/internal/logstore" "github.com/zitadel/zitadel/internal/logstore/record" @@ -34,21 +36,21 @@ func Test_isHostBlocked(t *testing.T) { name string fields fields args args - want bool + want error }{ { name: "in range", args: args{ address: mustNewURL(t, "https://192.168.5.4/hodor"), }, - want: true, + want: NewAddressDeniedError("192.168.5.0/24"), }, { name: "exact ip", args: args{ address: mustNewURL(t, "http://127.0.0.1:8080/hodor"), }, - want: true, + want: NewAddressDeniedError("127.0.0.1"), }, { name: "address match", @@ -60,7 +62,7 @@ func Test_isHostBlocked(t *testing.T) { args: args{ address: mustNewURL(t, "https://test.com:42/hodor"), }, - want: true, + want: NewAddressDeniedError("test.com"), }, { name: "address not match", @@ -72,7 +74,7 @@ func Test_isHostBlocked(t *testing.T) { args: args{ address: mustNewURL(t, "https://test2.com/hodor"), }, - want: false, + want: nil, }, { name: "looked up ip matches", @@ -84,7 +86,19 @@ func Test_isHostBlocked(t *testing.T) { args: args{ address: mustNewURL(t, "https://test2.com/hodor"), }, - want: true, + want: NewAddressDeniedError("127.0.0.1"), + }, + { + name: "looked up failure", + fields: fields{ + lookup: func(host string) ([]net.IP, error) { + return nil, errors.New("some error") + }, + }, + args: args{ + address: mustNewURL(t, "https://test2.com/hodor"), + }, + want: zerrors.ThrowInternal(nil, "ACTIO-4m9s2", "lookup failed"), }, } for _, tt := range tests { @@ -92,9 +106,8 @@ func Test_isHostBlocked(t *testing.T) { trans := &transport{ lookup: tt.fields.lookup, } - if got := trans.isHostBlocked(denyList, tt.args.address); got != tt.want { - t.Errorf("isHostBlocked() = %v, want %v", got, tt.want) - } + got := trans.isHostBlocked(denyList, tt.args.address) + assert.ErrorIs(t, got, tt.want) }) } }