mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-18 02:48:40 +00:00
tsweb: rewrite JSONHandler without using reflect (#684)
Closes #656 #657 Signed-off-by: Zijie Lu <zijie@tailscale.com>
This commit is contained in:
parent
93ffc565e5
commit
1835bb6f85
@ -7,7 +7,6 @@ package tsweb
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type response struct {
|
type response struct {
|
||||||
@ -16,119 +15,59 @@ type response struct {
|
|||||||
Data interface{} `json:"data,omitempty"`
|
Data interface{} `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseSuccess(data interface{}) *response {
|
// TODO: Header
|
||||||
return &response{
|
|
||||||
Status: "success",
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseError(e string) *response {
|
// JSONHandlerFunc only take *http.Request as argument to avoid any misuse of http.ResponseWriter.
|
||||||
return &response{
|
// The function's results must be (status int, data interface{}, err error).
|
||||||
Status: "error",
|
// Return a HTTPError to show an error message, otherwise JSONHandler will only show "internal server error".
|
||||||
Error: e,
|
type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeResponse(w http.ResponseWriter, s int, resp *response) {
|
// ServeHTTP calls the JSONHandlerFunc and automatically marshals http responses.
|
||||||
b, _ := json.Marshal(resp)
|
//
|
||||||
|
// Use the following code to unmarshal the request body
|
||||||
|
// body := new(DataType)
|
||||||
|
// if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||||
|
// return http.StatusBadRequest, nil, err
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Check jsonhandler_text.go for examples
|
||||||
|
func (fn JSONHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(s)
|
var resp *response
|
||||||
|
status, data, err := fn(r)
|
||||||
|
if status == 0 {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
resp = &response{
|
||||||
|
Status: "error",
|
||||||
|
Error: "internal server error",
|
||||||
|
}
|
||||||
|
} else if err == nil {
|
||||||
|
resp = &response{
|
||||||
|
Status: "success",
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if werr, ok := err.(HTTPError); ok {
|
||||||
|
resp = &response{
|
||||||
|
Status: "error",
|
||||||
|
Error: werr.Msg,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resp = &response{
|
||||||
|
Status: "error",
|
||||||
|
Error: "internal server error",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"status":"error","error":"json marshal error"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(status)
|
||||||
w.Write(b)
|
w.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkFn(t reflect.Type) {
|
|
||||||
h := reflect.TypeOf(http.HandlerFunc(nil))
|
|
||||||
switch t.NumIn() {
|
|
||||||
case 2, 3:
|
|
||||||
if !t.In(0).AssignableTo(h.In(0)) {
|
|
||||||
panic("first argument must be http.ResponseWriter")
|
|
||||||
}
|
|
||||||
if !t.In(1).AssignableTo(h.In(1)) {
|
|
||||||
panic("second argument must be *http.Request")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
panic("JSONHandler: number of input parameter should be 2 or 3")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch t.NumOut() {
|
|
||||||
case 1:
|
|
||||||
if !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
|
||||||
panic("return value must be error")
|
|
||||||
}
|
|
||||||
case 2:
|
|
||||||
if !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
|
||||||
panic("second return value must be error")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
panic("JSONHandler: number of return values should be 1 or 2")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// JSONHandler wraps an HTTP handler function with a version that automatically
|
|
||||||
// unmarshals and marshals requests and responses respectively into fn's arguments
|
|
||||||
// and results.
|
|
||||||
//
|
|
||||||
// The fn parameter is a function. It must take two or three input arguments.
|
|
||||||
// The first two arguments must be http.ResponseWriter and *http.Request.
|
|
||||||
// The optional third argument can be of any type representing the JSON input.
|
|
||||||
// The function's results can be either (error) or (T, error), where T is the
|
|
||||||
// JSON-marshalled result type.
|
|
||||||
//
|
|
||||||
// For example:
|
|
||||||
// fn := func(w http.ResponseWriter, r *http.Request, in *Req) (*Res, error) { ... }
|
|
||||||
func JSONHandler(fn interface{}) http.Handler {
|
|
||||||
v := reflect.ValueOf(fn)
|
|
||||||
t := v.Type()
|
|
||||||
checkFn(t)
|
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
wv := reflect.ValueOf(w)
|
|
||||||
rv := reflect.ValueOf(r)
|
|
||||||
var vs []reflect.Value
|
|
||||||
|
|
||||||
switch t.NumIn() {
|
|
||||||
case 2:
|
|
||||||
vs = v.Call([]reflect.Value{wv, rv})
|
|
||||||
case 3:
|
|
||||||
dv := reflect.New(t.In(2))
|
|
||||||
err := json.NewDecoder(r.Body).Decode(dv.Interface())
|
|
||||||
if err != nil {
|
|
||||||
writeResponse(w, http.StatusBadRequest, responseError("bad json"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
vs = v.Call([]reflect.Value{wv, rv, dv.Elem()})
|
|
||||||
default:
|
|
||||||
panic("JSONHandler: number of input parameter should be 2 or 3")
|
|
||||||
}
|
|
||||||
|
|
||||||
var e reflect.Value
|
|
||||||
switch len(vs) {
|
|
||||||
case 1:
|
|
||||||
// todo support other error types
|
|
||||||
if vs[0].IsZero() {
|
|
||||||
writeResponse(w, http.StatusOK, responseSuccess(nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
e = vs[0]
|
|
||||||
case 2:
|
|
||||||
if vs[1].IsZero() {
|
|
||||||
if !vs[0].IsZero() {
|
|
||||||
writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface()))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
e = vs[1]
|
|
||||||
default:
|
|
||||||
panic("JSONHandler: number of return values should be 1 or 2")
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.Type().AssignableTo(reflect.TypeOf(HTTPError{})) {
|
|
||||||
err := e.Interface().(HTTPError)
|
|
||||||
writeResponse(w, err.Code, responseError(err.Error()))
|
|
||||||
} else {
|
|
||||||
err := e.Interface().(error)
|
|
||||||
writeResponse(w, http.StatusBadRequest, responseError(err.Error()))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
@ -5,9 +5,8 @@
|
|||||||
package tsweb
|
package tsweb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@ -26,7 +25,7 @@ type Response struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNewJSONHandler(t *testing.T) {
|
func TestNewJSONHandler(t *testing.T) {
|
||||||
checkStatus := func(w *httptest.ResponseRecorder, status string) *Response {
|
checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response {
|
||||||
d := &Response{
|
d := &Response{
|
||||||
Data: &Data{},
|
Data: &Data{},
|
||||||
}
|
}
|
||||||
@ -44,6 +43,10 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
t.Fatalf("wrong status: %s %s", d.Status, status)
|
t.Fatalf("wrong status: %s %s", d.Status, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if w.Code != code {
|
||||||
|
t.Fatalf("wrong status code: %d %d", w.Code, code)
|
||||||
|
}
|
||||||
|
|
||||||
if w.Header().Get("Content-Type") != "application/json" {
|
if w.Header().Get("Content-Type") != "application/json" {
|
||||||
t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type"))
|
t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type"))
|
||||||
}
|
}
|
||||||
@ -51,163 +54,139 @@ func TestNewJSONHandler(t *testing.T) {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2 1
|
h21 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
h21 := JSONHandler(func(w http.ResponseWriter, r *http.Request) error {
|
return http.StatusOK, nil, nil
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("2 1 simple", func(t *testing.T) {
|
t.Run("200 simple", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
h21.ServeHTTP(w, r)
|
h21.ServeHTTP(w, r)
|
||||||
checkStatus(w, "success")
|
checkStatus(w, "success", http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("2 1 HTTPError", func(t *testing.T) {
|
t.Run("403 HTTPError", func(t *testing.T) {
|
||||||
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) HTTPError {
|
h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
return Error(http.StatusForbidden, "forbidden", nil)
|
return http.StatusForbidden, nil, fmt.Errorf("forbidden")
|
||||||
})
|
})
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
if w.Code != http.StatusForbidden {
|
checkStatus(w, "error", http.StatusForbidden)
|
||||||
t.Fatalf("wrong code: %d %d", w.Code, http.StatusForbidden)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// 2 2
|
h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) {
|
return http.StatusOK, &Data{Name: "tailscale"}, nil
|
||||||
return &Data{Name: "tailscale"}, nil
|
|
||||||
})
|
})
|
||||||
t.Run("2 2 get data", func(t *testing.T) {
|
|
||||||
|
t.Run("200 get data", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
h22.ServeHTTP(w, r)
|
h22.ServeHTTP(w, r)
|
||||||
checkStatus(w, "success")
|
checkStatus(w, "success", http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 3 1
|
h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
h31 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) error {
|
body := new(Data)
|
||||||
if d.Name == "" {
|
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||||
return errors.New("name is empty")
|
return http.StatusBadRequest, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
if body.Name == "" {
|
||||||
|
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "name is empty", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.StatusOK, nil, nil
|
||||||
})
|
})
|
||||||
t.Run("3 1 post data", func(t *testing.T) {
|
t.Run("200 post data", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
|
||||||
h31.ServeHTTP(w, r)
|
h31.ServeHTTP(w, r)
|
||||||
checkStatus(w, "success")
|
checkStatus(w, "success", http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("3 1 bad json", func(t *testing.T) {
|
t.Run("400 bad json", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
|
||||||
h31.ServeHTTP(w, r)
|
h31.ServeHTTP(w, r)
|
||||||
checkStatus(w, "error")
|
checkStatus(w, "error", http.StatusBadRequest)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("3 1 post data error", func(t *testing.T) {
|
t.Run("400 post data error", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||||
h31.ServeHTTP(w, r)
|
h31.ServeHTTP(w, r)
|
||||||
resp := checkStatus(w, "error")
|
resp := checkStatus(w, "error", http.StatusBadRequest)
|
||||||
if resp.Error != "name is empty" {
|
if resp.Error != "name is empty" {
|
||||||
t.Fatalf("wrong error")
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// 3 2
|
h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
h32 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) (*Data, error) {
|
body := new(Data)
|
||||||
if d.Price == 0 {
|
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||||
return nil, errors.New("price is empty")
|
return http.StatusBadRequest, nil, err
|
||||||
|
}
|
||||||
|
if body.Name == "root" {
|
||||||
|
return http.StatusInternalServerError, nil, fmt.Errorf("invalid name")
|
||||||
|
}
|
||||||
|
if body.Price == 0 {
|
||||||
|
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "price is empty", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Data{Price: d.Price * 2}, nil
|
return http.StatusOK, &Data{Price: body.Price * 2}, nil
|
||||||
})
|
})
|
||||||
t.Run("3 2 post data", func(t *testing.T) {
|
|
||||||
|
t.Run("200 post data", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
||||||
h32.ServeHTTP(w, r)
|
h32.ServeHTTP(w, r)
|
||||||
resp := checkStatus(w, "success")
|
resp := checkStatus(w, "success", http.StatusOK)
|
||||||
t.Log(resp.Data)
|
t.Log(resp.Data)
|
||||||
if resp.Data.Price != 20 {
|
if resp.Data.Price != 20 {
|
||||||
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
|
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("3 2 post data error", func(t *testing.T) {
|
t.Run("400 post data error", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||||
h32.ServeHTTP(w, r)
|
h32.ServeHTTP(w, r)
|
||||||
resp := checkStatus(w, "error")
|
resp := checkStatus(w, "error", http.StatusBadRequest)
|
||||||
if resp.Error != "price is empty" {
|
if resp.Error != "price is empty" {
|
||||||
t.Fatalf("wrong error")
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// fn check
|
t.Run("500 internal server error", func(t *testing.T) {
|
||||||
shouldPanic := func() {
|
|
||||||
r := recover()
|
|
||||||
if r == nil {
|
|
||||||
t.Fatalf("should panic")
|
|
||||||
}
|
|
||||||
t.Log(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("2 0 panic", func(t *testing.T) {
|
|
||||||
defer shouldPanic()
|
|
||||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("2 1 panic return value", func(t *testing.T) {
|
|
||||||
defer shouldPanic()
|
|
||||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) string {
|
|
||||||
return ""
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("2 1 panic arguments", func(t *testing.T) {
|
|
||||||
defer shouldPanic()
|
|
||||||
JSONHandler(func(r *http.Request, w http.ResponseWriter) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("3 1 panic arguments", func(t *testing.T) {
|
|
||||||
defer shouldPanic()
|
|
||||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("3 2 panic return value", func(t *testing.T) {
|
|
||||||
defer shouldPanic()
|
|
||||||
//lint:ignore ST1008 intentional
|
|
||||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) (error, string) {
|
|
||||||
return nil, "panic"
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("2 2 forbidden", func(t *testing.T) {
|
|
||||||
code := http.StatusForbidden
|
|
||||||
body := []byte("forbidden")
|
|
||||||
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) {
|
|
||||||
w.WriteHeader(code)
|
|
||||||
w.Write(body)
|
|
||||||
return nil, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
|
||||||
h.ServeHTTP(w, r)
|
h32.ServeHTTP(w, r)
|
||||||
if w.Code != http.StatusForbidden {
|
resp := checkStatus(w, "error", http.StatusInternalServerError)
|
||||||
t.Fatalf("wrong code: %d %d", w.Code, code)
|
if resp.Error != "internal server error" {
|
||||||
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(w.Body.Bytes(), []byte("forbidden")) {
|
})
|
||||||
t.Fatalf("wrong body: %s %s", w.Body.Bytes(), body)
|
|
||||||
|
t.Run("500 misuse", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("POST", "/", nil)
|
||||||
|
JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||||
|
return http.StatusOK, make(chan int), nil
|
||||||
|
}).ServeHTTP(w, r)
|
||||||
|
resp := checkStatus(w, "error", http.StatusInternalServerError)
|
||||||
|
if resp.Error != "json marshal error" {
|
||||||
|
t.Fatalf("wrong error")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("500 empty status code", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("POST", "/", nil)
|
||||||
|
JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
|
||||||
|
return
|
||||||
|
}).ServeHTTP(w, r)
|
||||||
|
checkStatus(w, "error", http.StatusInternalServerError)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user