add generic logerr func to shorten code

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-22 16:38:57 +02:00 committed by Kristoffer Dalby
parent fe75b71620
commit 665a3cc666
2 changed files with 44 additions and 114 deletions

View File

@ -111,10 +111,7 @@ func (h *Headscale) RegisterOIDC(
writer.WriteHeader(http.StatusUnauthorized) writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized")) _, err := writer.Write([]byte("Unauthorized"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return return
@ -137,10 +134,7 @@ func (h *Headscale) RegisterOIDC(
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params")) _, err := writer.Write([]byte("Wrong params"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return return
@ -148,9 +142,8 @@ func (h *Headscale) RegisterOIDC(
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil { if _, err := rand.Read(randomBlob); err != nil {
log.Error(). util.LogErr(err, "could not read 16 bytes from rand")
Caller().
Msg("could not read 16 bytes from rand")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
@ -274,10 +267,7 @@ func (h *Headscale) OIDCCallback(
writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil { if _, err := writer.Write(content.Bytes()); err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
} }
@ -293,10 +283,7 @@ func validateOIDCCallbackParams(
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params")) _, err := writer.Write([]byte("Wrong params"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return "", "", errEmptyOIDCCallbackParams return "", "", errEmptyOIDCCallbackParams
@ -312,18 +299,12 @@ func (h *Headscale) getIDTokenForOIDCCallback(
) (string, error) { ) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(ctx, code) oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Could not exchange code for token")
Err(err).
Caller().
Msg("Could not exchange code for token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Could not exchange code for token")) _, werr := writer.Write([]byte("Could not exchange code for token"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return "", err return "", err
@ -341,10 +322,7 @@ func (h *Headscale) getIDTokenForOIDCCallback(
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Could not extract ID Token")) _, err := writer.Write([]byte("Could not extract ID Token"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return "", errNoOIDCIDToken return "", errNoOIDCIDToken
@ -361,18 +339,12 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken) idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "failed to verify id token")
Err(err).
Caller().
Msg("failed to verify id token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to verify id token")) _, werr := writer.Write([]byte("Failed to verify id token"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return nil, err return nil, err
@ -387,18 +359,13 @@ func extractIDTokenClaims(
) (*IDTokenClaims, error) { ) (*IDTokenClaims, error) {
var claims IDTokenClaims var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil { if err := idToken.Claims(&claims); err != nil {
log.Error(). util.LogErr(err, "Failed to decode id token claims")
Err(err).
Caller().
Msg("Failed to decode id token claims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims")) _, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return nil, err return nil, err
@ -417,15 +384,13 @@ func validateOIDCAllowedDomains(
if len(allowedDomains) > 0 { if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 || if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) { !util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
log.Error().Msg("authenticated principal does not match any allowed domain") log.Trace().Msg("authenticated principal does not match any allowed domain")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (domain mismatch)")) _, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return errOIDCAllowedDomains return errOIDCAllowedDomains
@ -451,15 +416,12 @@ func validateOIDCAllowedGroups(
} }
} }
log.Error().Msg("authenticated principal not in any allowed groups") log.Trace().Msg("authenticated principal not in any allowed groups")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (allowed groups)")) _, err := writer.Write([]byte("unauthorized principal (allowed groups)"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return errOIDCAllowedGroups return errOIDCAllowedGroups
@ -477,15 +439,12 @@ func validateOIDCAllowedUsers(
) error { ) error {
if len(allowedUsers) > 0 && if len(allowedUsers) > 0 &&
!util.IsStringInSlice(allowedUsers, claims.Email) { !util.IsStringInSlice(allowedUsers, claims.Email) {
log.Error().Msg("authenticated principal does not match any allowed user") log.Trace().Msg("authenticated principal does not match any allowed user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (user mismatch)")) _, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return errOIDCAllowedUsers return errOIDCAllowedUsers
@ -507,16 +466,13 @@ func (h *Headscale) validateMachineForOIDCCallback(
// retrieve machinekey from state cache // retrieve machinekey from state cache
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state) nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !nodeKeyFound { if !nodeKeyFound {
log.Error(). log.Trace().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired")) _, err := writer.Write([]byte("state has expired"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return nil, false, errOIDCNodeKeyMissing return nil, false, errOIDCNodeKeyMissing
@ -525,16 +481,13 @@ func (h *Headscale) validateMachineForOIDCCallback(
var nodeKey key.NodePublic var nodeKey key.NodePublic
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
if !nodeKeyOK { if !nodeKeyOK {
log.Error(). log.Trace().
Msg("requested machine state key is not a string") Msg("requested machine state key is not a string")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid")) _, err := writer.Write([]byte("state is invalid"))
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return nil, false, errOIDCInvalidMachineState return nil, false, errOIDCInvalidMachineState
@ -552,10 +505,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("could not parse node public key")) _, werr := writer.Write([]byte("could not parse node public key"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return nil, false, err return nil, false, err
@ -575,10 +525,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
err := h.db.RefreshMachine(machine, expiry) err := h.db.RefreshMachine(machine, expiry)
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to refresh machine")
Caller().
Err(err).
Msg("Failed to refresh machine")
http.Error( http.Error(
writer, writer,
"Failed to refresh machine", "Failed to refresh machine",
@ -607,10 +554,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template")) _, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return nil, true, err return nil, true, err
@ -620,10 +564,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes()) _, err = writer.Write(content.Bytes())
if err != nil { if err != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
return nil, true, nil return nil, true, nil
@ -642,15 +583,13 @@ func getUserName(
stripEmaildomain, stripEmaildomain,
) )
if err != nil { if err != nil {
log.Error().Err(err).Caller().Msgf("couldn't normalize email") util.LogErr(err, "couldn't normalize email")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email")) _, werr := writer.Write([]byte("couldn't normalize email"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return "", err return "", err
@ -666,7 +605,6 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
user, err := h.db.GetUser(userName) user, err := h.db.GetUser(userName)
if errors.Is(err, db.ErrUserNotFound) { if errors.Is(err, db.ErrUserNotFound) {
user, err = h.db.CreateUser(userName) user, err = h.db.CreateUser(userName)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -676,10 +614,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not create user")) _, werr := writer.Write([]byte("could not create user"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return nil, err return nil, err
@ -694,10 +629,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not find or create user")) _, werr := writer.Write([]byte("could not find or create user"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return nil, err return nil, err
@ -720,18 +652,12 @@ func (h *Headscale) registerMachineForOIDCCallback(
&expiry, &expiry,
util.RegisterMethodOIDC, util.RegisterMethodOIDC,
); err != nil { ); err != nil {
log.Error(). util.LogErr(err, "could not register machine")
Caller().
Err(err).
Msg("could not register machine")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not register machine")) _, werr := writer.Write([]byte("could not register machine"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return err return err
@ -759,10 +685,7 @@ func renderOIDCCallbackTemplate(
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template")) _, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil { if werr != nil {
log.Error(). util.LogErr(err, "Failed to write response")
Caller().
Err(werr).
Msg("Failed to write response")
} }
return nil, err return nil, err

7
hscontrol/util/log.go Normal file
View File

@ -0,0 +1,7 @@
package util
import "github.com/rs/zerolog/log"
func LogErr(err error, msg string) {
log.Error().Caller().Err(err).Msg(msg)
}