From df8f02db3f4929060f6baae848aa1cb73f8b4413 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 7 Feb 2022 10:01:23 -0800 Subject: [PATCH] tsweb: add gzip support to JSONHandlerFunc Change-Id: I337e05f92f744bfc7e9d6fb8e67c87c191ba4da8 Signed-off-by: Brad Fitzpatrick --- tsweb/jsonhandler.go | 78 ++++++++++++++++++++++++++++++++- tsweb/jsonhandler_test.go | 91 ++++++++++++++++++++++++++++++++------- 2 files changed, 151 insertions(+), 18 deletions(-) diff --git a/tsweb/jsonhandler.go b/tsweb/jsonhandler.go index 9b3a0378b..89e96e89d 100644 --- a/tsweb/jsonhandler.go +++ b/tsweb/jsonhandler.go @@ -5,9 +5,17 @@ package tsweb import ( + "bytes" + "compress/gzip" "encoding/json" "fmt" + "io/ioutil" "net/http" + "strconv" + "strings" + "sync" + + "go4.org/mem" ) type response struct { @@ -85,7 +93,73 @@ func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request return jerr } - w.WriteHeader(status) - w.Write(b) + if AcceptsEncoding(r, "gzip") { + encb, err := gzipBytes(b) + if err != nil { + return err + } + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("Content-Length", strconv.Itoa(len(encb))) + w.Write(encb) + } else { + w.Header().Set("Content-Length", strconv.Itoa(len(b))) + w.WriteHeader(status) + w.Write(b) + } return err } + +var gzWriterPool sync.Pool // of *gzip.Writer + +// gzipBytes returns the gzipped encoding of b. +func gzipBytes(b []byte) (zb []byte, err error) { + var buf bytes.Buffer + zw, ok := gzWriterPool.Get().(*gzip.Writer) + if ok { + zw.Reset(&buf) + } else { + zw = gzip.NewWriter(&buf) + } + defer gzWriterPool.Put(zw) + if _, err := zw.Write(b); err != nil { + return nil, err + } + if err := zw.Close(); err != nil { + return nil, err + } + zb = buf.Bytes() + zw.Reset(ioutil.Discard) + return zb, nil +} + +// AcceptsEncoding reports whether r accepts the named encoding +// ("gzip", "br", etc). +func AcceptsEncoding(r *http.Request, enc string) bool { + h := r.Header.Get("Accept-Encoding") + if h == "" { + return false + } + if !strings.Contains(h, enc) && !mem.ContainsFold(mem.S(h), mem.S(enc)) { + return false + } + remain := h + for len(remain) > 0 { + comma := strings.Index(remain, ",") + var part string + if comma == -1 { + part = remain + remain = "" + } else { + part = remain[:comma] + remain = remain[comma+1:] + } + part = strings.TrimSpace(part) + if i := strings.Index(part, ";"); i != -1 { + part = part[:i] + } + if part == enc { + return true + } + } + return false +} diff --git a/tsweb/jsonhandler_test.go b/tsweb/jsonhandler_test.go index b36d55d89..8cf55c2ab 100644 --- a/tsweb/jsonhandler_test.go +++ b/tsweb/jsonhandler_test.go @@ -5,8 +5,11 @@ package tsweb import ( + "bytes" + "compress/gzip" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -27,13 +30,25 @@ type Response struct { } func TestNewJSONHandler(t *testing.T) { - checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response { + checkStatus := func(t *testing.T, w *httptest.ResponseRecorder, status string, code int) *Response { d := &Response{ Data: &Data{}, } - t.Logf("%s", w.Body.Bytes()) - err := json.Unmarshal(w.Body.Bytes(), d) + bodyBytes := w.Body.Bytes() + if w.Result().Header.Get("Content-Encoding") == "gzip" { + zr, err := gzip.NewReader(bytes.NewReader(bodyBytes)) + if err != nil { + t.Fatalf("gzip read error at start: %v", err) + } + bodyBytes, err = io.ReadAll(zr) + if err != nil { + t.Fatalf("gzip read error: %v", err) + } + } + + t.Logf("%s", bodyBytes) + err := json.Unmarshal(bodyBytes, d) if err != nil { t.Logf(err.Error()) return nil @@ -64,7 +79,7 @@ func TestNewJSONHandler(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h21.ServeHTTPReturn(w, r) - checkStatus(w, "success", http.StatusOK) + checkStatus(t, w, "success", http.StatusOK) }) t.Run("403 HTTPError", func(t *testing.T) { @@ -75,7 +90,7 @@ func TestNewJSONHandler(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h.ServeHTTPReturn(w, r) - checkStatus(w, "error", http.StatusForbidden) + checkStatus(t, w, "error", http.StatusForbidden) }) h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { @@ -86,7 +101,7 @@ func TestNewJSONHandler(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h22.ServeHTTPReturn(w, r) - checkStatus(w, "success", http.StatusOK) + checkStatus(t, w, "success", http.StatusOK) }) h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { @@ -105,21 +120,21 @@ func TestNewJSONHandler(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) h31.ServeHTTPReturn(w, r) - checkStatus(w, "success", http.StatusOK) + checkStatus(t, w, "success", http.StatusOK) }) t.Run("400 bad json", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) h31.ServeHTTPReturn(w, r) - checkStatus(w, "error", http.StatusBadRequest) + checkStatus(t, w, "error", http.StatusBadRequest) }) t.Run("400 post data error", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h31.ServeHTTPReturn(w, r) - resp := checkStatus(w, "error", http.StatusBadRequest) + resp := checkStatus(t, w, "error", http.StatusBadRequest) if resp.Error != "name is empty" { t.Fatalf("wrong error") } @@ -144,7 +159,23 @@ func TestNewJSONHandler(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) h32.ServeHTTPReturn(w, r) - resp := checkStatus(w, "success", http.StatusOK) + resp := checkStatus(t, w, "success", http.StatusOK) + t.Log(resp.Data) + if resp.Data.Price != 20 { + t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) + } + }) + + t.Run("gzipped", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) + r.Header.Set("Accept-Encoding", "gzip") + h32.ServeHTTPReturn(w, r) + res := w.Result() + if ct := res.Header.Get("Content-Encoding"); ct != "gzip" { + t.Fatalf("encoding = %q; want gzip", ct) + } + resp := checkStatus(t, w, "success", http.StatusOK) t.Log(resp.Data) if resp.Data.Price != 20 { t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) @@ -155,7 +186,7 @@ func TestNewJSONHandler(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h32.ServeHTTPReturn(w, r) - resp := checkStatus(w, "error", http.StatusBadRequest) + resp := checkStatus(t, w, "error", http.StatusBadRequest) if resp.Error != "price is empty" { t.Fatalf("wrong error") } @@ -165,7 +196,7 @@ func TestNewJSONHandler(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`)) h32.ServeHTTPReturn(w, r) - resp := checkStatus(w, "error", http.StatusInternalServerError) + resp := checkStatus(t, w, "error", http.StatusInternalServerError) if resp.Error != "internal server error" { t.Fatalf("wrong error") } @@ -177,7 +208,7 @@ func TestNewJSONHandler(t *testing.T) { JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { return http.StatusOK, make(chan int), nil }).ServeHTTPReturn(w, r) - resp := checkStatus(w, "error", http.StatusInternalServerError) + resp := checkStatus(t, w, "error", http.StatusInternalServerError) if resp.Error != "json marshal error" { t.Fatalf("wrong error") } @@ -189,7 +220,7 @@ func TestNewJSONHandler(t *testing.T) { JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) { return }).ServeHTTPReturn(w, r) - checkStatus(w, "error", http.StatusInternalServerError) + checkStatus(t, w, "error", http.StatusInternalServerError) }) t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) { @@ -203,7 +234,7 @@ func TestNewJSONHandler(t *testing.T) { Data: &Data{}, Error: "403 forbidden", } - got := checkStatus(w, "error", http.StatusForbidden) + got := checkStatus(t, w, "error", http.StatusForbidden) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf(diff) } @@ -223,9 +254,37 @@ func TestNewJSONHandler(t *testing.T) { Data: &Data{}, Error: "403 forbidden", } - got := checkStatus(w, "error", http.StatusForbidden) + got := checkStatus(t, w, "error", http.StatusForbidden) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf("(-want,+got):\n%s", diff) } }) } + +func TestAcceptsEncoding(t *testing.T) { + tests := []struct { + in, enc string + want bool + }{ + {"", "gzip", false}, + {"gzip", "gzip", true}, + {"foo,gzip", "gzip", true}, + {"foo, gzip", "gzip", true}, + {"foo, gzip ", "gzip", true}, + {"gzip, foo ", "gzip", true}, + {"gzip, foo ", "br", false}, + {"gzip, foo ", "fo", false}, + {"gzip;q=1.2, foo ", "gzip", true}, + {" gzip;q=1.2, foo ", "gzip", true}, + } + for i, tt := range tests { + h := make(http.Header) + if tt.in != "" { + h.Set("Accept-Encoding", tt.in) + } + got := AcceptsEncoding(&http.Request{Header: h}, tt.enc) + if got != tt.want { + t.Errorf("%d. got %v; want %v", i, got, tt.want) + } + } +}