mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 08:07:42 +00:00
tsweb: add gzip support to JSONHandlerFunc
Change-Id: I337e05f92f744bfc7e9d6fb8e67c87c191ba4da8 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
16652ae52c
commit
df8f02db3f
@ -5,9 +5,17 @@
|
|||||||
package tsweb
|
package tsweb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"go4.org/mem"
|
||||||
)
|
)
|
||||||
|
|
||||||
type response struct {
|
type response struct {
|
||||||
@ -85,7 +93,73 @@ func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request
|
|||||||
return jerr
|
return jerr
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(status)
|
if AcceptsEncoding(r, "gzip") {
|
||||||
w.Write(b)
|
encb, err := gzipBytes(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(encb)))
|
||||||
|
w.Write(encb)
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(b)))
|
||||||
|
w.WriteHeader(status)
|
||||||
|
w.Write(b)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var gzWriterPool sync.Pool // of *gzip.Writer
|
||||||
|
|
||||||
|
// gzipBytes returns the gzipped encoding of b.
|
||||||
|
func gzipBytes(b []byte) (zb []byte, err error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
zw, ok := gzWriterPool.Get().(*gzip.Writer)
|
||||||
|
if ok {
|
||||||
|
zw.Reset(&buf)
|
||||||
|
} else {
|
||||||
|
zw = gzip.NewWriter(&buf)
|
||||||
|
}
|
||||||
|
defer gzWriterPool.Put(zw)
|
||||||
|
if _, err := zw.Write(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := zw.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
zb = buf.Bytes()
|
||||||
|
zw.Reset(ioutil.Discard)
|
||||||
|
return zb, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptsEncoding reports whether r accepts the named encoding
|
||||||
|
// ("gzip", "br", etc).
|
||||||
|
func AcceptsEncoding(r *http.Request, enc string) bool {
|
||||||
|
h := r.Header.Get("Accept-Encoding")
|
||||||
|
if h == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !strings.Contains(h, enc) && !mem.ContainsFold(mem.S(h), mem.S(enc)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
remain := h
|
||||||
|
for len(remain) > 0 {
|
||||||
|
comma := strings.Index(remain, ",")
|
||||||
|
var part string
|
||||||
|
if comma == -1 {
|
||||||
|
part = remain
|
||||||
|
remain = ""
|
||||||
|
} else {
|
||||||
|
part = remain[:comma]
|
||||||
|
remain = remain[comma+1:]
|
||||||
|
}
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if i := strings.Index(part, ";"); i != -1 {
|
||||||
|
part = part[:i]
|
||||||
|
}
|
||||||
|
if part == enc {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -5,8 +5,11 @@
|
|||||||
package tsweb
|
package tsweb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@ -27,13 +30,25 @@ type Response struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNewJSONHandler(t *testing.T) {
|
func TestNewJSONHandler(t *testing.T) {
|
||||||
checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response {
|
checkStatus := func(t *testing.T, w *httptest.ResponseRecorder, status string, code int) *Response {
|
||||||
d := &Response{
|
d := &Response{
|
||||||
Data: &Data{},
|
Data: &Data{},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("%s", w.Body.Bytes())
|
bodyBytes := w.Body.Bytes()
|
||||||
err := json.Unmarshal(w.Body.Bytes(), d)
|
if w.Result().Header.Get("Content-Encoding") == "gzip" {
|
||||||
|
zr, err := gzip.NewReader(bytes.NewReader(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gzip read error at start: %v", err)
|
||||||
|
}
|
||||||
|
bodyBytes, err = io.ReadAll(zr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gzip read error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("%s", bodyBytes)
|
||||||
|
err := json.Unmarshal(bodyBytes, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf(err.Error())
|
t.Logf(err.Error())
|
||||||
return nil
|
return nil
|
||||||
@ -64,7 +79,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
h21.ServeHTTPReturn(w, r)
|
h21.ServeHTTPReturn(w, r)
|
||||||
checkStatus(w, "success", http.StatusOK)
|
checkStatus(t, w, "success", http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("403 HTTPError", func(t *testing.T) {
|
t.Run("403 HTTPError", func(t *testing.T) {
|
||||||
@ -75,7 +90,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
h.ServeHTTPReturn(w, r)
|
h.ServeHTTPReturn(w, r)
|
||||||
checkStatus(w, "error", http.StatusForbidden)
|
checkStatus(t, w, "error", http.StatusForbidden)
|
||||||
})
|
})
|
||||||
|
|
||||||
h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
@ -86,7 +101,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
h22.ServeHTTPReturn(w, r)
|
h22.ServeHTTPReturn(w, r)
|
||||||
checkStatus(w, "success", http.StatusOK)
|
checkStatus(t, w, "success", http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
@ -105,21 +120,21 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
|
||||||
h31.ServeHTTPReturn(w, r)
|
h31.ServeHTTPReturn(w, r)
|
||||||
checkStatus(w, "success", http.StatusOK)
|
checkStatus(t, w, "success", http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("400 bad json", func(t *testing.T) {
|
t.Run("400 bad json", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
|
||||||
h31.ServeHTTPReturn(w, r)
|
h31.ServeHTTPReturn(w, r)
|
||||||
checkStatus(w, "error", http.StatusBadRequest)
|
checkStatus(t, w, "error", http.StatusBadRequest)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("400 post data error", func(t *testing.T) {
|
t.Run("400 post data error", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||||
h31.ServeHTTPReturn(w, r)
|
h31.ServeHTTPReturn(w, r)
|
||||||
resp := checkStatus(w, "error", http.StatusBadRequest)
|
resp := checkStatus(t, w, "error", http.StatusBadRequest)
|
||||||
if resp.Error != "name is empty" {
|
if resp.Error != "name is empty" {
|
||||||
t.Fatalf("wrong error")
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
@ -144,7 +159,23 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
||||||
h32.ServeHTTPReturn(w, r)
|
h32.ServeHTTPReturn(w, r)
|
||||||
resp := checkStatus(w, "success", http.StatusOK)
|
resp := checkStatus(t, w, "success", http.StatusOK)
|
||||||
|
t.Log(resp.Data)
|
||||||
|
if resp.Data.Price != 20 {
|
||||||
|
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("gzipped", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
||||||
|
r.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
h32.ServeHTTPReturn(w, r)
|
||||||
|
res := w.Result()
|
||||||
|
if ct := res.Header.Get("Content-Encoding"); ct != "gzip" {
|
||||||
|
t.Fatalf("encoding = %q; want gzip", ct)
|
||||||
|
}
|
||||||
|
resp := checkStatus(t, w, "success", http.StatusOK)
|
||||||
t.Log(resp.Data)
|
t.Log(resp.Data)
|
||||||
if resp.Data.Price != 20 {
|
if resp.Data.Price != 20 {
|
||||||
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
|
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
|
||||||
@ -155,7 +186,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||||
h32.ServeHTTPReturn(w, r)
|
h32.ServeHTTPReturn(w, r)
|
||||||
resp := checkStatus(w, "error", http.StatusBadRequest)
|
resp := checkStatus(t, w, "error", http.StatusBadRequest)
|
||||||
if resp.Error != "price is empty" {
|
if resp.Error != "price is empty" {
|
||||||
t.Fatalf("wrong error")
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
@ -165,7 +196,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
|
||||||
h32.ServeHTTPReturn(w, r)
|
h32.ServeHTTPReturn(w, r)
|
||||||
resp := checkStatus(w, "error", http.StatusInternalServerError)
|
resp := checkStatus(t, w, "error", http.StatusInternalServerError)
|
||||||
if resp.Error != "internal server error" {
|
if resp.Error != "internal server error" {
|
||||||
t.Fatalf("wrong error")
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
@ -177,7 +208,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
return http.StatusOK, make(chan int), nil
|
return http.StatusOK, make(chan int), nil
|
||||||
}).ServeHTTPReturn(w, r)
|
}).ServeHTTPReturn(w, r)
|
||||||
resp := checkStatus(w, "error", http.StatusInternalServerError)
|
resp := checkStatus(t, w, "error", http.StatusInternalServerError)
|
||||||
if resp.Error != "json marshal error" {
|
if resp.Error != "json marshal error" {
|
||||||
t.Fatalf("wrong error")
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
@ -189,7 +220,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
|
JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
|
||||||
return
|
return
|
||||||
}).ServeHTTPReturn(w, r)
|
}).ServeHTTPReturn(w, r)
|
||||||
checkStatus(w, "error", http.StatusInternalServerError)
|
checkStatus(t, w, "error", http.StatusInternalServerError)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) {
|
t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) {
|
||||||
@ -203,7 +234,7 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
Data: &Data{},
|
Data: &Data{},
|
||||||
Error: "403 forbidden",
|
Error: "403 forbidden",
|
||||||
}
|
}
|
||||||
got := checkStatus(w, "error", http.StatusForbidden)
|
got := checkStatus(t, w, "error", http.StatusForbidden)
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
t.Fatalf(diff)
|
t.Fatalf(diff)
|
||||||
}
|
}
|
||||||
@ -223,9 +254,37 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
Data: &Data{},
|
Data: &Data{},
|
||||||
Error: "403 forbidden",
|
Error: "403 forbidden",
|
||||||
}
|
}
|
||||||
got := checkStatus(w, "error", http.StatusForbidden)
|
got := checkStatus(t, w, "error", http.StatusForbidden)
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
t.Fatalf("(-want,+got):\n%s", diff)
|
t.Fatalf("(-want,+got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAcceptsEncoding(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
in, enc string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"", "gzip", false},
|
||||||
|
{"gzip", "gzip", true},
|
||||||
|
{"foo,gzip", "gzip", true},
|
||||||
|
{"foo, gzip", "gzip", true},
|
||||||
|
{"foo, gzip ", "gzip", true},
|
||||||
|
{"gzip, foo ", "gzip", true},
|
||||||
|
{"gzip, foo ", "br", false},
|
||||||
|
{"gzip, foo ", "fo", false},
|
||||||
|
{"gzip;q=1.2, foo ", "gzip", true},
|
||||||
|
{" gzip;q=1.2, foo ", "gzip", true},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
h := make(http.Header)
|
||||||
|
if tt.in != "" {
|
||||||
|
h.Set("Accept-Encoding", tt.in)
|
||||||
|
}
|
||||||
|
got := AcceptsEncoding(&http.Request{Header: h}, tt.enc)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("%d. got %v; want %v", i, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user