From 69b49ac0edd117bc44e2f065f6ea63c162551c84 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Fri, 18 Aug 2023 15:51:11 +0200 Subject: [PATCH] fix(api): return correct http code on assets api (#6388) * fix(api): return correct http code on assets api * add test * fix test --- internal/api/assets/asset.go | 20 ++--- internal/api/http/error.go | 47 +++++++++++ internal/api/http/error_test.go | 138 ++++++++++++++++++++++++++++++++ 3 files changed, 195 insertions(+), 10 deletions(-) create mode 100644 internal/api/http/error.go create mode 100644 internal/api/http/error_test.go diff --git a/internal/api/assets/asset.go b/internal/api/assets/asset.go index 08ba7130a9..0714348a90 100644 --- a/internal/api/assets/asset.go +++ b/internal/api/assets/asset.go @@ -42,7 +42,7 @@ func (h *Handler) Commands() *command.Commands { } func (h *Handler) ErrorHandler() ErrorHandler { - return DefaultErrorHandler + return h.errorHandler } func (h *Handler) Storage() static.Storage { @@ -75,10 +75,14 @@ type Downloader interface { ResourceOwner(ctx context.Context, ownerPath string) string } -type ErrorHandler func(http.ResponseWriter, *http.Request, error, int) +type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error, defaultCode int) -func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, code int) { +func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, defaultCode int) { logging.WithFields("uri", r.RequestURI).WithError(err).Warn("error occurred on asset api") + code, ok := http_util.ZitadelErrorToHTTPStatusCode(err) + if !ok { + code = defaultCode + } http.Error(w, err.Error(), code) } @@ -162,7 +166,7 @@ func UploadHandleFunc(s AssetsService, uploader Uploader) func(http.ResponseWrit } err = uploader.UploadAsset(ctx, ctxData.OrgID, uploadInfo, s.Commands()) if err != nil { - s.ErrorHandler()(w, r, fmt.Errorf("upload failed: %v", err), http.StatusInternalServerError) + s.ErrorHandler()(w, r, fmt.Errorf("upload failed: %w", err), http.StatusInternalServerError) return } } @@ -190,10 +194,6 @@ func DownloadHandleFunc(s AssetsService, downloader Downloader) func(http.Respon return } if err = GetAsset(w, r, resourceOwner, objectName, s.Storage()); err != nil { - if strings.Contains(err.Error(), "DATAB-pCP8P") { - s.ErrorHandler()(w, r, err, http.StatusNotFound) - return - } s.ErrorHandler()(w, r, err, http.StatusInternalServerError) } } @@ -206,11 +206,11 @@ func GetAsset(w http.ResponseWriter, r *http.Request, resourceOwner, objectName } data, getInfo, err := storage.GetObject(r.Context(), authz.GetInstance(r.Context()).InstanceID(), resourceOwner, objectName) if err != nil { - return fmt.Errorf("download failed: %v", err) + return fmt.Errorf("download failed: %w", err) } info, err := getInfo() if err != nil { - return fmt.Errorf("download failed: %v", err) + return fmt.Errorf("download failed: %w", err) } if info.Hash == strings.Trim(r.Header.Get(http_util.IfNoneMatch), "\"") { w.Header().Set(http_util.LastModified, info.LastModified.Format(time.RFC1123)) diff --git a/internal/api/http/error.go b/internal/api/http/error.go new file mode 100644 index 0000000000..e73f1def96 --- /dev/null +++ b/internal/api/http/error.go @@ -0,0 +1,47 @@ +package http + +import ( + "errors" + "net/http" + + caos_errs "github.com/zitadel/zitadel/internal/errors" +) + +func ZitadelErrorToHTTPStatusCode(err error) (statusCode int, ok bool) { + if err == nil { + return http.StatusOK, true + } + //nolint:errorlint + switch err.(type) { + case *caos_errs.AlreadyExistsError: + return http.StatusConflict, true + case *caos_errs.DeadlineExceededError: + return http.StatusGatewayTimeout, true + case *caos_errs.InternalError: + return http.StatusInternalServerError, true + case *caos_errs.InvalidArgumentError: + return http.StatusBadRequest, true + case *caos_errs.NotFoundError: + return http.StatusNotFound, true + case *caos_errs.PermissionDeniedError: + return http.StatusForbidden, true + case *caos_errs.PreconditionFailedError: + // use the same code as grpc-gateway: + // https://github.com/grpc-ecosystem/grpc-gateway/blob/9e33e38f15cb7d2f11096366e62ea391a3459ba9/runtime/errors.go#L59 + return http.StatusBadRequest, true + case *caos_errs.UnauthenticatedError: + return http.StatusUnauthorized, true + case *caos_errs.UnavailableError: + return http.StatusServiceUnavailable, true + case *caos_errs.UnimplementedError: + return http.StatusNotImplemented, true + case *caos_errs.ResourceExhaustedError: + return http.StatusTooManyRequests, true + default: + c := new(caos_errs.CaosError) + if errors.As(err, &c) { + return ZitadelErrorToHTTPStatusCode(errors.Unwrap(err)) + } + return http.StatusInternalServerError, false + } +} diff --git a/internal/api/http/error_test.go b/internal/api/http/error_test.go new file mode 100644 index 0000000000..08fed349d2 --- /dev/null +++ b/internal/api/http/error_test.go @@ -0,0 +1,138 @@ +package http + +import ( + "errors" + "fmt" + "net/http" + "testing" + + caos_errors "github.com/zitadel/zitadel/internal/errors" +) + +func TestZitadelErrorToHTTPStatusCode(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + wantStatusCode int + wantOk bool + }{ + { + name: "no error", + args: args{ + err: nil, + }, + wantStatusCode: http.StatusOK, + wantOk: true, + }, + { + name: "wrapped already exists", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowAlreadyExists(nil, "id", "message")), + }, + wantStatusCode: http.StatusConflict, + wantOk: true, + }, + { + name: "wrapped deadline exceeded", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowDeadlineExceeded(nil, "id", "message")), + }, + wantStatusCode: http.StatusGatewayTimeout, + wantOk: true, + }, + { + name: "wrapped internal", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowInternal(nil, "id", "message")), + }, + wantStatusCode: http.StatusInternalServerError, + wantOk: true, + }, + { + name: "wrapped invalid argument", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowInvalidArgument(nil, "id", "message")), + }, + wantStatusCode: http.StatusBadRequest, + wantOk: true, + }, + { + name: "wrapped not found", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowNotFound(nil, "id", "message")), + }, + wantStatusCode: http.StatusNotFound, + wantOk: true, + }, + { + name: "wrapped permission denied", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowPermissionDenied(nil, "id", "message")), + }, + wantStatusCode: http.StatusForbidden, + wantOk: true, + }, + { + name: "wrapped precondition failed", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowPreconditionFailed(nil, "id", "message")), + }, + wantStatusCode: http.StatusBadRequest, + wantOk: true, + }, + { + name: "wrapped unauthenticated", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowUnauthenticated(nil, "id", "message")), + }, + wantStatusCode: http.StatusUnauthorized, + wantOk: true, + }, + { + name: "wrapped unavailable", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowUnavailable(nil, "id", "message")), + }, + wantStatusCode: http.StatusServiceUnavailable, + wantOk: true, + }, + { + name: "wrapped unimplemented", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowUnimplemented(nil, "id", "message")), + }, + wantStatusCode: http.StatusNotImplemented, + wantOk: true, + }, + { + name: "wrapped resource exhausted", + args: args{ + err: fmt.Errorf("wrapped %w", caos_errors.ThrowResourceExhausted(nil, "id", "message")), + }, + wantStatusCode: http.StatusTooManyRequests, + wantOk: true, + }, + { + name: "no caos/zitadel error", + args: args{ + err: errors.New("error"), + }, + wantStatusCode: http.StatusInternalServerError, + wantOk: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotStatusCode, gotOk := ZitadelErrorToHTTPStatusCode(tt.args.err) + if gotStatusCode != tt.wantStatusCode { + t.Errorf("ZitadelErrorToHTTPStatusCode() gotStatusCode = %v, want %v", gotStatusCode, tt.wantStatusCode) + } + if gotOk != tt.wantOk { + t.Errorf("ZitadelErrorToHTTPStatusCode() gotOk = %v, want %v", gotOk, tt.wantOk) + } + }) + } +}