fix: improve exhausted SetCookie header (#5789)

* fix: remove access interceptor for console

* feat: template quota cookie value

* fix: send exhausted cookie from grpc-gateway

* refactor: remove ineffectual err assignments

* Update internal/api/grpc/server/gateway.go

Co-authored-by: Livio Spring <livio.a@gmail.com>

* use dynamic host header to find instance

* add instance mgmt url to environment.json

* support hosts with default ports

* fix linting

* docs: update lb example

* print access logs to stdout

* fix grpc gateway exhausted cookies

* cleanup

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Elio Bischof 2023-05-11 09:24:44 +02:00 committed by GitHub
parent c2cb84cd24
commit 35a0977663
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 208 additions and 63 deletions

View File

@ -267,6 +267,7 @@ Console:
LongCache: LongCache:
MaxAge: 12h MaxAge: 12h
SharedMaxAge: 168h #7d SharedMaxAge: 168h #7d
InstanceManagementURL: ""
Notification: Notification:
Repository: Repository:

View File

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
_ "embed" _ "embed"
"fmt" "fmt"
"math"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -301,8 +302,14 @@ func startAPIs(
if accessSvc.Enabled() { if accessSvc.Enabled() {
logging.Warn("access logs are currently in beta") logging.Warn("access logs are currently in beta")
} }
accessInterceptor := middleware.NewAccessInterceptor(accessSvc, config.Quotas.Access) exhaustedCookieHandler := http_util.NewCookieHandler(
apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, accessSvc) http_util.WithUnsecure(),
http_util.WithNonHttpOnly(),
http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))),
)
limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access, false)
nonLimitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, nil, config.Quotas.Access, true)
apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, accessSvc, exhaustedCookieHandler, config.Quotas.Access)
if err != nil { if err != nil {
return fmt.Errorf("error creating api %w", err) return fmt.Errorf("error creating api %w", err)
} }
@ -334,7 +341,7 @@ func startAPIs(
} }
instanceInterceptor := middleware.InstanceInterceptor(queries, config.HTTP1HostHeader, login.IgnoreInstanceEndpoints...) instanceInterceptor := middleware.InstanceInterceptor(queries, config.HTTP1HostHeader, login.IgnoreInstanceEndpoints...)
assetsCache := middleware.AssetsCacheInterceptor(config.AssetStorage.Cache.MaxAge, config.AssetStorage.Cache.SharedMaxAge) assetsCache := middleware.AssetsCacheInterceptor(config.AssetStorage.Cache.MaxAge, config.AssetStorage.Cache.SharedMaxAge)
apis.RegisterHandlerOnPrefix(assets.HandlerPrefix, assets.NewHandler(commands, verifier, config.InternalAuthZ, id.SonyFlakeGenerator(), store, queries, middleware.CallDurationHandler, instanceInterceptor.Handler, assetsCache.Handler, accessInterceptor.Handle)) apis.RegisterHandlerOnPrefix(assets.HandlerPrefix, assets.NewHandler(commands, verifier, config.InternalAuthZ, id.SonyFlakeGenerator(), store, queries, middleware.CallDurationHandler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.Handle))
userAgentInterceptor, err := middleware.NewUserAgentHandler(config.UserAgentCookie, keys.UserAgentCookieKey, id.SonyFlakeGenerator(), config.ExternalSecure, login.EndpointResources) userAgentInterceptor, err := middleware.NewUserAgentHandler(config.UserAgentCookie, keys.UserAgentCookieKey, id.SonyFlakeGenerator(), config.ExternalSecure, login.EndpointResources)
if err != nil { if err != nil {
@ -355,25 +362,25 @@ func startAPIs(
} }
apis.RegisterHandlerOnPrefix(openapi.HandlerPrefix, openAPIHandler) apis.RegisterHandlerOnPrefix(openapi.HandlerPrefix, openAPIHandler)
oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, accessInterceptor.Handle) oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, limitingAccessInterceptor.Handle)
if err != nil { if err != nil {
return fmt.Errorf("unable to start oidc provider: %w", err) return fmt.Errorf("unable to start oidc provider: %w", err)
} }
apis.RegisterHandlerPrefixes(oidcProvider.HttpHandler(), "/.well-known/openid-configuration", "/oidc/v1", "/oauth/v2") apis.RegisterHandlerPrefixes(oidcProvider.HttpHandler(), "/.well-known/openid-configuration", "/oidc/v1", "/oauth/v2")
samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, accessInterceptor.Handle) samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, limitingAccessInterceptor.Handle)
if err != nil { if err != nil {
return fmt.Errorf("unable to start saml provider: %w", err) return fmt.Errorf("unable to start saml provider: %w", err)
} }
apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler()) apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler())
c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, accessInterceptor.Handle, config.CustomerPortal) c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, nonLimitingAccessInterceptor.Handle, config.CustomerPortal)
if err != nil { if err != nil {
return fmt.Errorf("unable to start console: %w", err) return fmt.Errorf("unable to start console: %w", err)
} }
apis.RegisterHandlerOnPrefix(console.HandlerPrefix, c) apis.RegisterHandlerOnPrefix(console.HandlerPrefix, c)
l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, accessInterceptor.Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey) l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey)
if err != nil { if err != nil {
return fmt.Errorf("unable to start login: %w", err) return fmt.Errorf("unable to start login: %w", err)
} }

View File

@ -4,7 +4,7 @@ services:
traefik: traefik:
networks: networks:
- 'zitadel' - 'zitadel'
image: "traefik:v2.7" image: "traefik:v2.10.1"
ports: ports:
- "80:80" - "80:80"
- "443:443" - "443:443"

View File

@ -1,3 +1,8 @@
log:
level: DEBUG
accessLog: {}
entrypoints: entrypoints:
web: web:
address: ":80" address: ":80"
@ -7,7 +12,7 @@ entrypoints:
tls: tls:
stores: stores:
default: default:
# generates self-signed certificates # generates self-signed certificates
defaultCertificate: defaultCertificate:

View File

@ -23,3 +23,8 @@ Database:
RootCert: "/crdb-certs/ca.crt" RootCert: "/crdb-certs/ca.crt"
Cert: "/crdb-certs/client.root.crt" Cert: "/crdb-certs/client.root.crt"
Key: "/crdb-certs/client.root.key" Key: "/crdb-certs/client.root.key"
LogStore:
Access:
Stdout:
Enabled: true

View File

@ -70,3 +70,8 @@ This is the IAM admin users login according to your configuration in the [exampl
- **password**: *RootPassword1!* - **password**: *RootPassword1!*
Read more about [the login process](/guides/integrate/login-users). Read more about [the login process](/guides/integrate/login-users).
## Troubleshooting
You can connect to cockroach like this: `docker exec -it loadbalancing-example-my-cockroach-db-1 cockroach sql --host my-cockroach-db --certs-dir /cockroach/certs/`
For example, to show all login names: `docker exec -it loadbalancing-example-my-cockroach-db-1 cockroach sql --database zitadel --host my-cockroach-db --certs-dir /cockroach/certs/ --execute "select * from projections.login_names2"`

View File

@ -107,14 +107,9 @@ describe('quotas', () => {
}, },
}); });
}); });
expectCookieDoesntExist();
const expiresMax = new Date(); const expiresMax = new Date();
expiresMax.setMinutes(expiresMax.getMinutes() + 2); expiresMax.setMinutes(expiresMax.getMinutes() + 2);
cy.getCookie('zitadel.quota.limiting').then((cookie) => {
expect(cookie.value).to.equal('false');
const cookieExpiry = new Date();
cookieExpiry.setTime(cookie.expiry * 1000);
expect(cookieExpiry).to.be.within(start, expiresMax);
});
cy.request({ cy.request({
url: urls[0], url: urls[0],
method: 'GET', method: 'GET',
@ -127,12 +122,16 @@ describe('quotas', () => {
}); });
cy.getCookie('zitadel.quota.limiting').then((cookie) => { cy.getCookie('zitadel.quota.limiting').then((cookie) => {
expect(cookie.value).to.equal('true'); expect(cookie.value).to.equal('true');
const cookieExpiry = new Date();
cookieExpiry.setTime(cookie.expiry * 1000);
expect(cookieExpiry).to.be.within(start, expiresMax);
}); });
createHumanUser(ctx.api, testUserName, false).then((res) => { createHumanUser(ctx.api, testUserName, false).then((res) => {
expect(res.status).to.equal(429); expect(res.status).to.equal(429);
}); });
ensureQuotaIsRemoved(ctx, Unit.AuthenticatedRequests); ensureQuotaIsRemoved(ctx, Unit.AuthenticatedRequests);
createHumanUser(ctx.api, testUserName); createHumanUser(ctx.api, testUserName);
expectCookieDoesntExist();
}); });
}); });
}); });
@ -301,3 +300,9 @@ describe('quotas', () => {
}); });
}); });
}); });
function expectCookieDoesntExist() {
cy.getCookie('zitadel.quota.limiting').then((cookie) => {
expect(cookie).to.be.null;
});
}

View File

@ -34,6 +34,9 @@ type API struct {
http1HostName string http1HostName string
grpcGateway *server.Gateway grpcGateway *server.Gateway
healthServer *health.Server healthServer *health.Server
cookieHandler *http_util.CookieHandler
cookieConfig *http_mw.AccessConfig
queries *query.Queries
} }
type healthCheck interface { type healthCheck interface {
@ -49,6 +52,8 @@ func New(
authZ internal_authz.Config, authZ internal_authz.Config,
tlsConfig *tls.Config, http2HostName, http1HostName string, tlsConfig *tls.Config, http2HostName, http1HostName string,
accessSvc *logstore.Service, accessSvc *logstore.Service,
cookieHandler *http_util.CookieHandler,
cookieConfig *http_mw.AccessConfig,
) (_ *API, err error) { ) (_ *API, err error) {
api := &API{ api := &API{
port: port, port: port,
@ -56,10 +61,13 @@ func New(
health: queries, health: queries,
router: router, router: router,
http1HostName: http1HostName, http1HostName: http1HostName,
cookieConfig: cookieConfig,
cookieHandler: cookieHandler,
queries: queries,
} }
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc) api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc)
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName) api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, cookieHandler, cookieConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,7 +85,15 @@ func New(
// used for v1 api (system, admin, mgmt, auth) // used for v1 api (system, admin, mgmt, auth)
func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayPrefix) error { func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayPrefix) error {
grpcServer.RegisterServer(a.grpcServer) grpcServer.RegisterServer(a.grpcServer)
handler, prefix, err := server.CreateGatewayWithPrefix(ctx, grpcServer, a.port, a.http1HostName) handler, prefix, err := server.CreateGatewayWithPrefix(
ctx,
grpcServer,
a.port,
a.http1HostName,
a.cookieHandler,
a.cookieConfig,
a.queries,
)
if err != nil { if err != nil {
return err return err
} }

View File

@ -16,7 +16,9 @@ import (
client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware" client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware" "github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_utils "github.com/zitadel/zitadel/internal/api/http"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware" http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/query"
) )
const ( const (
@ -67,10 +69,13 @@ type Gateway struct {
mux *runtime.ServeMux mux *runtime.ServeMux
http1HostName string http1HostName string
connection *grpc.ClientConn connection *grpc.ClientConn
cookieHandler *http_utils.CookieHandler
cookieConfig *http_mw.AccessConfig
queries *query.Queries
} }
func (g *Gateway) Handler() http.Handler { func (g *Gateway) Handler() http.Handler {
return addInterceptors(g.mux, g.http1HostName) return addInterceptors(g.mux, g.http1HostName, g.cookieHandler, g.cookieConfig, g.queries)
} }
type CustomHTTPResponse interface { type CustomHTTPResponse interface {
@ -79,7 +84,15 @@ type CustomHTTPResponse interface {
type RegisterGatewayFunc func(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error type RegisterGatewayFunc func(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error
func CreateGatewayWithPrefix(ctx context.Context, g WithGatewayPrefix, port uint16, http1HostName string) (http.Handler, string, error) { func CreateGatewayWithPrefix(
ctx context.Context,
g WithGatewayPrefix,
port uint16,
http1HostName string,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
queries *query.Queries,
) (http.Handler, string, error) {
runtimeMux := runtime.NewServeMux(serveMuxOptions...) runtimeMux := runtime.NewServeMux(serveMuxOptions...)
opts := []grpc.DialOption{ opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
@ -93,10 +106,10 @@ func CreateGatewayWithPrefix(ctx context.Context, g WithGatewayPrefix, port uint
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err) return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err)
} }
return addInterceptors(runtimeMux, http1HostName), g.GatewayPathPrefix(), nil return addInterceptors(runtimeMux, http1HostName, cookieHandler, cookieConfig, queries), g.GatewayPathPrefix(), nil
} }
func CreateGateway(ctx context.Context, port uint16, http1HostName string) (*Gateway, error) { func CreateGateway(ctx context.Context, port uint16, http1HostName string, cookieHandler *http_utils.CookieHandler, cookieConfig *http_mw.AccessConfig) (*Gateway, error) {
connection, err := dial(ctx, connection, err := dial(ctx,
port, port,
[]grpc.DialOption{ []grpc.DialOption{
@ -111,6 +124,8 @@ func CreateGateway(ctx context.Context, port uint16, http1HostName string) (*Gat
mux: runtimeMux, mux: runtimeMux,
http1HostName: http1HostName, http1HostName: http1HostName,
connection: connection, connection: connection,
cookieHandler: cookieHandler,
cookieConfig: cookieConfig,
}, nil }, nil
} }
@ -145,13 +160,23 @@ func dial(ctx context.Context, port uint16, opts []grpc.DialOption) (*grpc.Clien
return conn, nil return conn, nil
} }
func addInterceptors(handler http.Handler, http1HostName string) http.Handler { func addInterceptors(
handler http.Handler,
http1HostName string,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
queries *query.Queries,
) http.Handler {
handler = http_mw.CallDurationHandler(handler) handler = http_mw.CallDurationHandler(handler)
handler = http1Host(handler, http1HostName) handler = http1Host(handler, http1HostName)
handler = http_mw.CORSInterceptor(handler) handler = http_mw.CORSInterceptor(handler)
handler = http_mw.RobotsTagHandler(handler) handler = http_mw.RobotsTagHandler(handler)
handler = http_mw.DefaultTelemetryHandler(handler) handler = http_mw.DefaultTelemetryHandler(handler)
return http_mw.DefaultMetricsHandler(handler) // For some non-obvious reason, the exhaustedCookieInterceptor sends the SetCookie header
// only if it follows the http_mw.DefaultTelemetryHandler
handler = exhaustedCookieInterceptor(handler, cookieHandler, cookieConfig, queries)
handler = http_mw.DefaultMetricsHandler(handler)
return handler
} }
func http1Host(next http.Handler, http1HostName string) http.Handler { func http1Host(next http.Handler, http1HostName string) http.Handler {
@ -165,3 +190,38 @@ func http1Host(next http.Handler, http1HostName string) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func exhaustedCookieInterceptor(
next http.Handler,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
queries *query.Queries,
) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
next.ServeHTTP(&cookieResponseWriter{
ResponseWriter: writer,
cookieHandler: cookieHandler,
cookieConfig: cookieConfig,
request: request,
queries: queries,
}, request)
})
}
type cookieResponseWriter struct {
http.ResponseWriter
cookieHandler *http_utils.CookieHandler
cookieConfig *http_mw.AccessConfig
request *http.Request
queries *query.Queries
}
func (r *cookieResponseWriter) WriteHeader(status int) {
if status >= 200 && status < 300 {
http_mw.DeleteExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.request, r.cookieConfig)
}
if status == http.StatusTooManyRequests {
http_mw.SetExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.cookieConfig, r.request)
}
r.ResponseWriter.WriteHeader(status)
}

View File

@ -1,15 +1,16 @@
package middleware package middleware
import ( import (
"math" "net"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strings"
"time" "time"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_utils "github.com/zitadel/zitadel/internal/api/http" http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/logstore" "github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/emitters/access" "github.com/zitadel/zitadel/internal/logstore/emitters/access"
@ -20,6 +21,7 @@ type AccessInterceptor struct {
svc *logstore.Service svc *logstore.Service
cookieHandler *http_utils.CookieHandler cookieHandler *http_utils.CookieHandler
limitConfig *AccessConfig limitConfig *AccessConfig
storeOnly bool
} }
type AccessConfig struct { type AccessConfig struct {
@ -27,14 +29,15 @@ type AccessConfig struct {
ExhaustedCookieMaxAge time.Duration ExhaustedCookieMaxAge time.Duration
} }
func NewAccessInterceptor(svc *logstore.Service, cookieConfig *AccessConfig) *AccessInterceptor { // NewAccessInterceptor intercepts all requests and stores them to the logstore.
// If storeOnly is false, it also checks if requests are exhausted.
// If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig, storeOnly bool) *AccessInterceptor {
return &AccessInterceptor{ return &AccessInterceptor{
svc: svc, svc: svc,
cookieHandler: http_utils.NewCookieHandler( cookieHandler: cookieHandler,
http_utils.WithUnsecure(), limitConfig: cookieConfig,
http_utils.WithMaxAge(int(math.Floor(cookieConfig.ExhaustedCookieMaxAge.Seconds()))), storeOnly: storeOnly,
),
limitConfig: cookieConfig,
} }
} }
@ -44,36 +47,33 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
} }
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context() ctx := request.Context()
var err error
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess") tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0} wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
instance := authz.GetInstance(ctx) instance := authz.GetInstance(ctx)
remaining := a.svc.Limit(tracingCtx, instance.InstanceID()) limit := false
limit := remaining != nil && *remaining == 0 if !a.storeOnly {
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
a.cookieHandler.SetCookie(wrappedWriter, a.limitConfig.ExhaustedCookieKey, request.Host, strconv.FormatBool(limit)) limit = remaining != nil && *remaining == 0
if limit {
wrappedWriter.WriteHeader(http.StatusTooManyRequests)
wrappedWriter.ignoreWrites = true
} }
checkSpan.End() checkSpan.End()
if limit {
next.ServeHTTP(wrappedWriter, request) // Limit can only be true when storeOnly is false, so set the cookie and the response code
SetExhaustedCookie(a.cookieHandler, wrappedWriter, a.limitConfig, request)
http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests)
} else {
if !a.storeOnly {
// If not limited and not storeOnly, ensure the cookie is deleted
DeleteExhaustedCookie(a.cookieHandler, wrappedWriter, request, a.limitConfig)
}
// Always serve if not limited
next.ServeHTTP(wrappedWriter, request)
}
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess") tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
defer writeSpan.End() defer writeSpan.End()
requestURL := request.RequestURI requestURL := request.RequestURI
unescapedURL, err := url.QueryUnescape(requestURL) unescapedURL, err := url.QueryUnescape(requestURL)
if err != nil { if err != nil {
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url") logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
// err = nil is effective because of deferred tracing span end
err = nil
} }
a.svc.Handle(tracingCtx, &access.Record{ a.svc.Handle(tracingCtx, &access.Record{
LogDate: time.Now(), LogDate: time.Now(),
@ -90,6 +90,24 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
}) })
} }
func SetExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, cookieConfig *AccessConfig, request *http.Request) {
cookieValue := "true"
host := request.Header.Get(middleware.HTTP1Host)
domain := host
if strings.ContainsAny(host, ":") {
var err error
domain, _, err = net.SplitHostPort(host)
if err != nil {
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
}
}
cookieHandler.SetCookie(writer, cookieConfig.ExhaustedCookieKey, domain, cookieValue)
}
func DeleteExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, request *http.Request, cookieConfig *AccessConfig) {
cookieHandler.DeleteCookie(writer, request, cookieConfig.ExhaustedCookieKey)
}
type statusRecorder struct { type statusRecorder struct {
http.ResponseWriter http.ResponseWriter
status int status int

View File

@ -1,9 +1,11 @@
package console package console
import ( import (
"bytes"
"embed" "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"html/template"
"io/fs" "io/fs"
"net/http" "net/http"
"os" "os"
@ -22,8 +24,9 @@ import (
) )
type Config struct { type Config struct {
ShortCache middleware.CacheConfig ShortCache middleware.CacheConfig
LongCache middleware.CacheConfig LongCache middleware.CacheConfig
InstanceManagementURL string
} }
type spaHandler struct { type spaHandler struct {
@ -106,7 +109,13 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor) handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor)
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure) url := http_util.BuildOrigin(r.Host, externalSecure)
environmentJSON, err := createEnvironmentJSON(url, issuer(r), authz.GetInstance(r.Context()).ConsoleClientID(), customerPortal) instance := authz.GetInstance(r.Context())
instanceMgmtURL, err := templateInstanceManagementURL(config.InstanceManagementURL, instance)
if err != nil {
http.Error(w, fmt.Sprintf("unable to template instance management url for console: %v", err), http.StatusInternalServerError)
return
}
environmentJSON, err := createEnvironmentJSON(url, issuer(r), instance.ConsoleClientID(), customerPortal, instanceMgmtURL)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError)
return return
@ -118,6 +127,18 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
return handler, nil return handler, nil
} }
func templateInstanceManagementURL(templateableCookieValue string, instance authz.Instance) (string, error) {
cookieValueTemplate, err := template.New("cookievalue").Parse(templateableCookieValue)
if err != nil {
return templateableCookieValue, err
}
cookieValue := new(bytes.Buffer)
if err = cookieValueTemplate.Execute(cookieValue, instance); err != nil {
return templateableCookieValue, err
}
return cookieValue.String(), nil
}
func csp() *middleware.CSP { func csp() *middleware.CSP {
csp := middleware.DefaultSCP csp := middleware.DefaultSCP
csp.StyleSrc = csp.StyleSrc.AddInline() csp.StyleSrc = csp.StyleSrc.AddInline()
@ -127,17 +148,19 @@ func csp() *middleware.CSP {
return &csp return &csp
} }
func createEnvironmentJSON(api, issuer, clientID, customerPortal string) ([]byte, error) { func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUrl string) ([]byte, error) {
environment := struct { environment := struct {
API string `json:"api,omitempty"` API string `json:"api,omitempty"`
Issuer string `json:"issuer,omitempty"` Issuer string `json:"issuer,omitempty"`
ClientID string `json:"clientid,omitempty"` ClientID string `json:"clientid,omitempty"`
CustomerPortal string `json:"customer_portal,omitempty"` CustomerPortal string `json:"customer_portal,omitempty"`
InstanceManagementURL string `json:"instance_management_url,omitempty"`
}{ }{
API: api, API: api,
Issuer: issuer, Issuer: issuer,
ClientID: clientID, ClientID: clientID,
CustomerPortal: customerPortal, CustomerPortal: customerPortal,
InstanceManagementURL: instanceMgmtUrl,
} }
return json.Marshal(environment) return json.Marshal(environment)
} }