fix(api): return correct http code on assets api (#6388)

* fix(api): return correct http code on assets api

* add test

* fix test
This commit is contained in:
Livio Spring
2023-08-18 15:51:11 +02:00
committed by GitHub
parent 8b44794c75
commit 69b49ac0ed
3 changed files with 195 additions and 10 deletions

View File

@@ -42,7 +42,7 @@ func (h *Handler) Commands() *command.Commands {
}
func (h *Handler) ErrorHandler() ErrorHandler {
return DefaultErrorHandler
return h.errorHandler
}
func (h *Handler) Storage() static.Storage {
@@ -75,10 +75,14 @@ type Downloader interface {
ResourceOwner(ctx context.Context, ownerPath string) string
}
type ErrorHandler func(http.ResponseWriter, *http.Request, error, int)
type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error, defaultCode int)
func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, code int) {
func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, defaultCode int) {
logging.WithFields("uri", r.RequestURI).WithError(err).Warn("error occurred on asset api")
code, ok := http_util.ZitadelErrorToHTTPStatusCode(err)
if !ok {
code = defaultCode
}
http.Error(w, err.Error(), code)
}
@@ -162,7 +166,7 @@ func UploadHandleFunc(s AssetsService, uploader Uploader) func(http.ResponseWrit
}
err = uploader.UploadAsset(ctx, ctxData.OrgID, uploadInfo, s.Commands())
if err != nil {
s.ErrorHandler()(w, r, fmt.Errorf("upload failed: %v", err), http.StatusInternalServerError)
s.ErrorHandler()(w, r, fmt.Errorf("upload failed: %w", err), http.StatusInternalServerError)
return
}
}
@@ -190,10 +194,6 @@ func DownloadHandleFunc(s AssetsService, downloader Downloader) func(http.Respon
return
}
if err = GetAsset(w, r, resourceOwner, objectName, s.Storage()); err != nil {
if strings.Contains(err.Error(), "DATAB-pCP8P") {
s.ErrorHandler()(w, r, err, http.StatusNotFound)
return
}
s.ErrorHandler()(w, r, err, http.StatusInternalServerError)
}
}
@@ -206,11 +206,11 @@ func GetAsset(w http.ResponseWriter, r *http.Request, resourceOwner, objectName
}
data, getInfo, err := storage.GetObject(r.Context(), authz.GetInstance(r.Context()).InstanceID(), resourceOwner, objectName)
if err != nil {
return fmt.Errorf("download failed: %v", err)
return fmt.Errorf("download failed: %w", err)
}
info, err := getInfo()
if err != nil {
return fmt.Errorf("download failed: %v", err)
return fmt.Errorf("download failed: %w", err)
}
if info.Hash == strings.Trim(r.Header.Get(http_util.IfNoneMatch), "\"") {
w.Header().Set(http_util.LastModified, info.LastModified.Format(time.RFC1123))