mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:07:31 +00:00
fix: grpc gateway interceptors (#3767)
This commit is contained in:
@@ -14,6 +14,10 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP1Host = "x-zitadel-http1-host"
|
||||
)
|
||||
|
||||
type InstanceVerifier interface {
|
||||
GetInstance(ctx context.Context)
|
||||
}
|
||||
@@ -36,7 +40,7 @@ func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
|
||||
}
|
||||
}
|
||||
|
||||
host, err := hostNameFromContext(interceptorCtx, headerName)
|
||||
host, err := hostFromContext(interceptorCtx, headerName)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
@@ -48,12 +52,19 @@ func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
|
||||
return handler(authz.WithInstance(ctx, instance), req)
|
||||
}
|
||||
|
||||
func hostNameFromContext(ctx context.Context, headerName string) (string, error) {
|
||||
func hostFromContext(ctx context.Context, headerName string) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("cannot read metadata")
|
||||
}
|
||||
host, ok := md[headerName]
|
||||
host, ok := md[HTTP1Host]
|
||||
if ok && len(host) == 1 {
|
||||
if !isAllowedToSendHTTP1Header(md) {
|
||||
return "", fmt.Errorf("no valid host header")
|
||||
}
|
||||
return host[0], nil
|
||||
}
|
||||
host, ok = md[headerName]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("cannot find header: %v", headerName)
|
||||
}
|
||||
@@ -62,3 +73,11 @@ func hostNameFromContext(ctx context.Context, headerName string) (string, error)
|
||||
}
|
||||
return host[0], nil
|
||||
}
|
||||
|
||||
//isAllowedToSendHTTP1Header check if the gRPC call was sent to `localhost`
|
||||
//this is only possible when calling the server directly running on localhost
|
||||
//or through the gRPC gateway
|
||||
func isAllowedToSendHTTP1Header(md metadata.MD) bool {
|
||||
authority, ok := md[":authority"]
|
||||
return ok && len(authority) == 1 && strings.Split(authority[0], ":")[0] == "localhost"
|
||||
}
|
||||
|
@@ -63,13 +63,13 @@ func Test_hostNameFromContext(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := hostNameFromContext(tt.args.ctx, tt.args.headerName)
|
||||
got, err := hostFromContext(tt.args.ctx, tt.args.headerName)
|
||||
if (err != nil) != tt.res.err {
|
||||
t.Errorf("hostNameFromContext() error = %v, wantErr %v", err, tt.res.err)
|
||||
t.Errorf("hostFromContext() error = %v, wantErr %v", err, tt.res.err)
|
||||
return
|
||||
}
|
||||
if got != tt.res.want {
|
||||
t.Errorf("hostNameFromContext() got = %v, want %v", got, tt.res.want)
|
||||
t.Errorf("hostFromContext() got = %v, want %v", got, tt.res.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user