diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 90c2db1f01..a12fe474ba 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -600,7 +600,10 @@ Actions: # Wildcard sub domains are currently unsupported DenyList: # ZITADEL_ACTIONS_HTTP_DENYLIST (comma separated list) - localhost - - "127.0.0.1" + - "127.0.0.0/8" + - "::1" + - "0.0.0.0" + - "::" LogStore: Access: diff --git a/cmd/start/config_test.go b/cmd/start/config_test.go index 90d4b9d2dc..53c95d35ab 100644 --- a/cmd/start/config_test.go +++ b/cmd/start/config_test.go @@ -47,9 +47,9 @@ Log: `}, want: func(t *testing.T, config *Config) { assert.Equal(t, config.Actions.HTTP.DenyList, []actions.AddressChecker{ - &actions.DomainChecker{Domain: "localhost"}, - &actions.IPChecker{IP: net.ParseIP("127.0.0.1")}, - &actions.DomainChecker{Domain: "foobar"}}) + &actions.HostChecker{Domain: "localhost"}, + &actions.HostChecker{IP: net.ParseIP("127.0.0.1")}, + &actions.HostChecker{Domain: "foobar"}}) }, }, { name: "actions deny list string ok", @@ -63,9 +63,9 @@ Log: `}, want: func(t *testing.T, config *Config) { assert.Equal(t, config.Actions.HTTP.DenyList, []actions.AddressChecker{ - &actions.DomainChecker{Domain: "localhost"}, - &actions.IPChecker{IP: net.ParseIP("127.0.0.1")}, - &actions.DomainChecker{Domain: "foobar"}}) + &actions.HostChecker{Domain: "localhost"}, + &actions.HostChecker{IP: net.ParseIP("127.0.0.1")}, + &actions.HostChecker{Domain: "foobar"}}) }, }, { name: "features ok", diff --git a/internal/actions/http_module.go b/internal/actions/http_module.go index 33cfbc91bc..2f9d09932c 100644 --- a/internal/actions/http_module.go +++ b/internal/actions/http_module.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "io" + "net" "net/http" "net/url" "strings" @@ -19,7 +20,7 @@ import ( func WithHTTP(ctx context.Context) Option { return func(c *runConfig) { c.modules["zitadel/http"] = func(runtime *goja.Runtime, module *goja.Object) { - requireHTTP(ctx, &http.Client{Transport: new(transport)}, runtime, module) + requireHTTP(ctx, &http.Client{Transport: &transport{lookup: net.LookupIP}}, runtime, module) } } } @@ -170,21 +171,34 @@ func parseHeaders(headers *goja.Object) http.Header { return h } -type transport struct{} +type transport struct { + lookup func(string) ([]net.IP, error) +} -func (*transport) RoundTrip(req *http.Request) (*http.Response, error) { +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { if httpConfig == nil { return http.DefaultTransport.RoundTrip(req) } - if isHostBlocked(httpConfig.DenyList, req.URL) { + if t.isHostBlocked(httpConfig.DenyList, req.URL) { return nil, zerrors.ThrowInvalidArgument(nil, "ACTIO-N72d0", "host is denied") } return http.DefaultTransport.RoundTrip(req) } -func isHostBlocked(denyList []AddressChecker, address *url.URL) bool { +func (t *transport) isHostBlocked(denyList []AddressChecker, address *url.URL) bool { + host := address.Hostname() + ip := net.ParseIP(host) + ips := []net.IP{ip} + // if the hostname is a domain, we need to check resolve the ip(s), since it might be denied + if ip == nil { + var err error + ips, err = t.lookup(host) + if err != nil { + return true + } + } for _, blocked := range denyList { - if blocked.Matches(address.Hostname()) { + if blocked.Matches(ips, host) { return true } } @@ -192,5 +206,5 @@ func isHostBlocked(denyList []AddressChecker, address *url.URL) bool { } type AddressChecker interface { - Matches(string) bool + Matches([]net.IP, string) bool } diff --git a/internal/actions/http_module_config.go b/internal/actions/http_module_config.go index d10ad39676..d1b965814e 100644 --- a/internal/actions/http_module_config.go +++ b/internal/actions/http_module_config.go @@ -6,8 +6,6 @@ import ( "strings" "github.com/mitchellh/mapstructure" - - "github.com/zitadel/zitadel/internal/zerrors" ) func SetHTTPConfig(config *HTTPConfig) { @@ -48,7 +46,7 @@ func HTTPConfigDecodeHook(from, to reflect.Value) (interface{}, error) { for _, unsplit := range config.DenyList { for _, split := range strings.Split(unsplit, ",") { - parsed, parseErr := parseDenyListEntry(split) + parsed, parseErr := NewHostChecker(split) if parseErr != nil { return nil, parseErr } @@ -61,46 +59,36 @@ func HTTPConfigDecodeHook(from, to reflect.Value) (interface{}, error) { return c, nil } -func parseDenyListEntry(entry string) (AddressChecker, error) { - if checker, err := NewIPChecker(entry); err == nil { - return checker, nil - } - return &DomainChecker{Domain: entry}, nil -} - -func NewIPChecker(i string) (AddressChecker, error) { - _, network, err := net.ParseCIDR(i) +func NewHostChecker(entry string) (AddressChecker, error) { + _, network, err := net.ParseCIDR(entry) if err == nil { - return &IPChecker{Net: network}, nil + return &HostChecker{Net: network}, nil } - if ip := net.ParseIP(i); ip != nil { - return &IPChecker{IP: ip}, nil + if ip := net.ParseIP(entry); ip != nil { + return &HostChecker{IP: ip}, nil } - return nil, zerrors.ThrowInvalidArgument(nil, "ACTIO-ddJ7h", "invalid ip") + return &HostChecker{Domain: entry}, nil } -type IPChecker struct { - Net *net.IPNet - IP net.IP -} - -func (c *IPChecker) Matches(address string) bool { - ip := net.ParseIP(address) - if ip == nil { - return false - } - - if c.IP != nil { - return c.IP.Equal(ip) - } - return c.Net.Contains(ip) -} - -type DomainChecker struct { +type HostChecker struct { + Net *net.IPNet + IP net.IP Domain string } -func (c *DomainChecker) Matches(domain string) bool { - //TODO: allow wild cards - return c.Domain == domain +func (c *HostChecker) Matches(ips []net.IP, address string) bool { + // if the address matches the domain, no additional checks as needed + if c.Domain == address { + return true + } + // otherwise we need to check on ips (incl. the resolved ips of the host) + for _, ip := range ips { + if c.Net != nil && c.Net.Contains(ip) { + return true + } + if c.IP != nil && c.IP.Equal(ip) { + return true + } + } + return false } diff --git a/internal/actions/http_module_test.go b/internal/actions/http_module_test.go index 0d3bdef75e..7a1f8d7816 100644 --- a/internal/actions/http_module_test.go +++ b/internal/actions/http_module_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "net" "net/http" "net/url" "reflect" @@ -19,17 +20,21 @@ import ( func Test_isHostBlocked(t *testing.T) { SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil)) var denyList = []AddressChecker{ - mustNewIPChecker(t, "192.168.5.0/24"), - mustNewIPChecker(t, "127.0.0.1"), - &DomainChecker{Domain: "test.com"}, + mustNewHostChecker(t, "192.168.5.0/24"), + mustNewHostChecker(t, "127.0.0.1"), + mustNewHostChecker(t, "test.com"), + } + type fields struct { + lookup func(host string) ([]net.IP, error) } type args struct { address *url.URL } tests := []struct { - name string - args args - want bool + name string + fields fields + args args + want bool }{ { name: "in range", @@ -47,6 +52,11 @@ func Test_isHostBlocked(t *testing.T) { }, { name: "address match", + fields: fields{ + lookup: func(host string) ([]net.IP, error) { + return []net.IP{net.ParseIP("194.264.52.4")}, nil + }, + }, args: args{ address: mustNewURL(t, "https://test.com:42/hodor"), }, @@ -54,24 +64,44 @@ func Test_isHostBlocked(t *testing.T) { }, { name: "address not match", + fields: fields{ + lookup: func(host string) ([]net.IP, error) { + return []net.IP{net.ParseIP("194.264.52.4")}, nil + }, + }, args: args{ address: mustNewURL(t, "https://test2.com/hodor"), }, want: false, }, + { + name: "looked up ip matches", + fields: fields{ + lookup: func(host string) ([]net.IP, error) { + return []net.IP{net.ParseIP("127.0.0.1")}, nil + }, + }, + args: args{ + address: mustNewURL(t, "https://test2.com/hodor"), + }, + want: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := isHostBlocked(denyList, tt.args.address); got != tt.want { + trans := &transport{ + lookup: tt.fields.lookup, + } + if got := trans.isHostBlocked(denyList, tt.args.address); got != tt.want { t.Errorf("isHostBlocked() = %v, want %v", got, tt.want) } }) } } -func mustNewIPChecker(t *testing.T, ip string) AddressChecker { +func mustNewHostChecker(t *testing.T, ip string) AddressChecker { t.Helper() - checker, err := NewIPChecker(ip) + checker, err := NewHostChecker(ip) if err != nil { t.Errorf("unable to parse cidr of %q because: %v", ip, err) t.FailNow()