From 665a3cc666e4e6a7442b1019891e7fe1708fa276 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 22 Jun 2023 16:38:57 +0200 Subject: [PATCH] add generic logerr func to shorten code Signed-off-by: Kristoffer Dalby --- hscontrol/oidc.go | 151 +++++++++++------------------------------- hscontrol/util/log.go | 7 ++ 2 files changed, 44 insertions(+), 114 deletions(-) create mode 100644 hscontrol/util/log.go diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 732cb446..66383838 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -111,10 +111,7 @@ func (h *Headscale) RegisterOIDC( writer.WriteHeader(http.StatusUnauthorized) _, err := writer.Write([]byte("Unauthorized")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return @@ -137,10 +134,7 @@ func (h *Headscale) RegisterOIDC( writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("Wrong params")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return @@ -148,9 +142,8 @@ func (h *Headscale) RegisterOIDC( randomBlob := make([]byte, randomByteSize) if _, err := rand.Read(randomBlob); err != nil { - log.Error(). - Caller(). - Msg("could not read 16 bytes from rand") + util.LogErr(err, "could not read 16 bytes from rand") + http.Error(writer, "Internal server error", http.StatusInternalServerError) return @@ -274,10 +267,7 @@ func (h *Headscale) OIDCCallback( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) if _, err := writer.Write(content.Bytes()); err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } } @@ -293,10 +283,7 @@ func validateOIDCCallbackParams( writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("Wrong params")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return "", "", errEmptyOIDCCallbackParams @@ -312,18 +299,12 @@ func (h *Headscale) getIDTokenForOIDCCallback( ) (string, error) { oauth2Token, err := h.oauth2Config.Exchange(ctx, code) if err != nil { - log.Error(). - Err(err). - Caller(). - Msg("Could not exchange code for token") + util.LogErr(err, "Could not exchange code for token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, werr := writer.Write([]byte("Could not exchange code for token")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return "", err @@ -341,10 +322,7 @@ func (h *Headscale) getIDTokenForOIDCCallback( writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("Could not extract ID Token")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return "", errNoOIDCIDToken @@ -361,18 +339,12 @@ func (h *Headscale) verifyIDTokenForOIDCCallback( verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { - log.Error(). - Err(err). - Caller(). - Msg("failed to verify id token") + util.LogErr(err, "failed to verify id token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, werr := writer.Write([]byte("Failed to verify id token")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, err @@ -387,18 +359,13 @@ func extractIDTokenClaims( ) (*IDTokenClaims, error) { var claims IDTokenClaims if err := idToken.Claims(&claims); err != nil { - log.Error(). - Err(err). - Caller(). - Msg("Failed to decode id token claims") + util.LogErr(err, "Failed to decode id token claims") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, werr := writer.Write([]byte("Failed to decode id token claims")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, err @@ -417,15 +384,13 @@ func validateOIDCAllowedDomains( if len(allowedDomains) > 0 { if at := strings.LastIndex(claims.Email, "@"); at < 0 || !util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) { - log.Error().Msg("authenticated principal does not match any allowed domain") + log.Trace().Msg("authenticated principal does not match any allowed domain") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("unauthorized principal (domain mismatch)")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return errOIDCAllowedDomains @@ -451,15 +416,12 @@ func validateOIDCAllowedGroups( } } - log.Error().Msg("authenticated principal not in any allowed groups") + log.Trace().Msg("authenticated principal not in any allowed groups") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("unauthorized principal (allowed groups)")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return errOIDCAllowedGroups @@ -477,15 +439,12 @@ func validateOIDCAllowedUsers( ) error { if len(allowedUsers) > 0 && !util.IsStringInSlice(allowedUsers, claims.Email) { - log.Error().Msg("authenticated principal does not match any allowed user") + log.Trace().Msg("authenticated principal does not match any allowed user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("unauthorized principal (user mismatch)")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return errOIDCAllowedUsers @@ -507,16 +466,13 @@ func (h *Headscale) validateMachineForOIDCCallback( // retrieve machinekey from state cache nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state) if !nodeKeyFound { - log.Error(). + log.Trace(). Msg("requested machine state key expired before authorisation completed") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("state has expired")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, false, errOIDCNodeKeyMissing @@ -525,16 +481,13 @@ func (h *Headscale) validateMachineForOIDCCallback( var nodeKey key.NodePublic nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) if !nodeKeyOK { - log.Error(). + log.Trace(). Msg("requested machine state key is not a string") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("state is invalid")) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, false, errOIDCInvalidMachineState @@ -552,10 +505,7 @@ func (h *Headscale) validateMachineForOIDCCallback( writer.WriteHeader(http.StatusBadRequest) _, werr := writer.Write([]byte("could not parse node public key")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, false, err @@ -575,10 +525,7 @@ func (h *Headscale) validateMachineForOIDCCallback( err := h.db.RefreshMachine(machine, expiry) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to refresh machine") + util.LogErr(err, "Failed to refresh machine") http.Error( writer, "Failed to refresh machine", @@ -607,10 +554,7 @@ func (h *Headscale) validateMachineForOIDCCallback( writer.WriteHeader(http.StatusInternalServerError) _, werr := writer.Write([]byte("Could not render OIDC callback template")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, true, err @@ -620,10 +564,7 @@ func (h *Headscale) validateMachineForOIDCCallback( writer.WriteHeader(http.StatusOK) _, err = writer.Write(content.Bytes()) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, true, nil @@ -642,15 +583,13 @@ func getUserName( stripEmaildomain, ) if err != nil { - log.Error().Err(err).Caller().Msgf("couldn't normalize email") + util.LogErr(err, "couldn't normalize email") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) _, werr := writer.Write([]byte("couldn't normalize email")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return "", err @@ -666,7 +605,6 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( user, err := h.db.GetUser(userName) if errors.Is(err, db.ErrUserNotFound) { user, err = h.db.CreateUser(userName) - if err != nil { log.Error(). Err(err). @@ -676,10 +614,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( writer.WriteHeader(http.StatusInternalServerError) _, werr := writer.Write([]byte("could not create user")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, err @@ -694,10 +629,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( writer.WriteHeader(http.StatusInternalServerError) _, werr := writer.Write([]byte("could not find or create user")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, err @@ -720,18 +652,12 @@ func (h *Headscale) registerMachineForOIDCCallback( &expiry, util.RegisterMethodOIDC, ); err != nil { - log.Error(). - Caller(). - Err(err). - Msg("could not register machine") + util.LogErr(err, "could not register machine") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) _, werr := writer.Write([]byte("could not register machine")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return err @@ -759,10 +685,7 @@ func renderOIDCCallbackTemplate( writer.WriteHeader(http.StatusInternalServerError) _, werr := writer.Write([]byte("Could not render OIDC callback template")) if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write response") + util.LogErr(err, "Failed to write response") } return nil, err diff --git a/hscontrol/util/log.go b/hscontrol/util/log.go new file mode 100644 index 00000000..ebbdb792 --- /dev/null +++ b/hscontrol/util/log.go @@ -0,0 +1,7 @@ +package util + +import "github.com/rs/zerolog/log" + +func LogErr(err error, msg string) { + log.Error().Caller().Err(err).Msg(msg) +}