termstatus: fully wrap reading password from terminal

This commit is contained in:
Michael Eischer
2025-09-14 19:21:51 +02:00
parent 013c565c29
commit ff5a0cc851
7 changed files with 56 additions and 44 deletions

View File

@@ -1,7 +1,6 @@
package main
import (
"bufio"
"context"
"fmt"
"io"
@@ -32,7 +31,6 @@ import (
"github.com/restic/restic/internal/options"
"github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/terminal"
"github.com/restic/restic/internal/textfile"
"github.com/restic/restic/internal/ui"
"github.com/restic/restic/internal/ui/progress"
@@ -232,14 +230,6 @@ func loadPasswordFromFile(pwdFile string) (string, error) {
return strings.TrimSpace(string(s)), errors.Wrap(err, "Readfile")
}
// readPassword reads the password from the given reader directly.
func readPassword(in io.Reader) (password string, err error) {
sc := bufio.NewScanner(in)
sc.Scan()
return sc.Text(), errors.WithStack(sc.Err())
}
// ReadPassword reads the password from a password file, the environment
// variable RESTIC_PASSWORD or prompts the user. If the context is canceled,
// the function leaks the password reading goroutine.
@@ -255,20 +245,9 @@ func ReadPassword(ctx context.Context, gopts GlobalOptions, prompt string, print
return gopts.password, nil
}
var (
password string
err error
)
if gopts.term.InputIsTerminal() {
password, err = terminal.ReadPassword(ctx, os.Stdin, os.Stderr, prompt)
} else {
printer.PT("reading repository password from stdin")
password, err = readPassword(os.Stdin)
}
password, err := gopts.term.ReadPassword(ctx, prompt)
if err != nil {
return "", errors.Wrap(err, "unable to read password")
return "", fmt.Errorf("unable to read password: %w", err)
}
if len(password) == 0 {

View File

@@ -7,21 +7,10 @@ import (
"strings"
"testing"
"github.com/restic/restic/internal/errors"
rtest "github.com/restic/restic/internal/test"
"github.com/restic/restic/internal/ui/progress"
)
type errorReader struct{ err error }
func (r *errorReader) Read([]byte) (int, error) { return 0, r.err }
func TestReadPassword(t *testing.T) {
want := errors.New("foo")
_, err := readPassword(&errorReader{want})
rtest.Assert(t, errors.Is(err, want), "wrong error %v", err)
}
func TestReadRepo(t *testing.T) {
tempDir := rtest.TempDir(t)

View File

@@ -3,7 +3,7 @@ package terminal
import (
"context"
"fmt"
"os"
"io"
"golang.org/x/term"
)
@@ -12,11 +12,10 @@ import (
// tty. Prompt is printed on the writer out before attempting to read the
// password. If the context is canceled, the function leaks the password reading
// goroutine.
func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string) (password string, err error) {
fd := int(out.Fd())
state, err := term.GetState(fd)
func ReadPassword(ctx context.Context, inFd int, out io.Writer, prompt string) (password string, err error) {
state, err := term.GetState(inFd)
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "unable to get terminal state: %v\n", err)
_, _ = fmt.Fprintf(out, "unable to get terminal state: %v\n", err)
return "", err
}
@@ -29,7 +28,7 @@ func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string)
if err != nil {
return
}
buf, err = term.ReadPassword(int(in.Fd()))
buf, err = term.ReadPassword(inFd)
if err != nil {
return
}
@@ -38,9 +37,9 @@ func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string)
select {
case <-ctx.Done():
err := term.Restore(fd, state)
err := term.Restore(inFd, state)
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "unable to restore terminal state: %v\n", err)
_, _ = fmt.Fprintf(out, "unable to restore terminal state: %v\n", err)
}
return "", ctx.Err()
case <-done:

View File

@@ -1,6 +1,9 @@
package ui
import "io"
import (
"context"
"io"
)
var _ Terminal = &MockTerminal{}
@@ -33,6 +36,10 @@ func (m *MockTerminal) InputIsTerminal() bool {
return true
}
func (m *MockTerminal) ReadPassword(_ context.Context, _ string) (string, error) {
return "password", nil
}
func (m *MockTerminal) OutputRaw() io.Writer {
return nil
}

View File

@@ -1,6 +1,9 @@
package ui
import "io"
import (
"context"
"io"
)
// Terminal is used to write messages and display status lines which can be
// updated. See termstatus.Terminal for a concrete implementation.
@@ -15,6 +18,7 @@ type Terminal interface {
CanUpdateStatus() bool
InputRaw() io.ReadCloser
InputIsTerminal() bool
ReadPassword(ctx context.Context, prompt string) (string, error)
// OutputRaw returns the output writer. Should only be used if there is no
// other option. Must not be used in combination with Print, Error, SetStatus
// or any other method that writes to the terminal.

View File

@@ -1,6 +1,7 @@
package termstatus
import (
"bufio"
"context"
"fmt"
"io"
@@ -18,6 +19,7 @@ var _ ui.Terminal = &Terminal{}
// printed.
type Terminal struct {
rd io.ReadCloser
inFd uintptr
wr io.Writer
fd uintptr
errWriter io.Writer
@@ -100,6 +102,7 @@ func New(rd io.ReadCloser, wr io.Writer, errWriter io.Writer, disableStatus bool
if d, ok := rd.(fder); ok {
if terminal.InputIsTerminal(d.Fd()) {
t.inFd = d.Fd()
t.inputIsTerminal = true
}
}
@@ -130,6 +133,26 @@ func (t *Terminal) InputRaw() io.ReadCloser {
return t.rd
}
func (t *Terminal) ReadPassword(ctx context.Context, prompt string) (string, error) {
if t.InputIsTerminal() {
return terminal.ReadPassword(ctx, int(t.inFd), t.errWriter, prompt)
}
if t.OutputIsTerminal() {
t.Print("reading repository password from stdin")
}
return readPassword(t.rd)
}
// readPassword reads the password from the given reader directly.
func readPassword(in io.Reader) (password string, err error) {
sc := bufio.NewScanner(in)
sc.Scan()
if sc.Err() != nil {
return "", fmt.Errorf("readPassword: %w", sc.Err())
}
return sc.Text(), nil
}
// CanUpdateStatus return whether the status output is updated in place.
func (t *Terminal) CanUpdateStatus() bool {
return t.canUpdateStatus

View File

@@ -3,6 +3,7 @@ package termstatus
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"testing"
@@ -76,3 +77,13 @@ func TestSanitizeLines(t *testing.T) {
})
}
}
type errorReader struct{ err error }
func (r *errorReader) Read([]byte) (int, error) { return 0, r.err }
func TestReadPassword(t *testing.T) {
want := errors.New("foo")
_, err := readPassword(&errorReader{want})
rtest.Assert(t, errors.Is(err, want), "wrong error %v", err)
}