mirror of
https://github.com/restic/restic.git
synced 2025-12-12 07:41:50 +00:00
termstatus: fully wrap reading password from terminal
This commit is contained in:
@@ -3,7 +3,7 @@ package terminal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"io"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
@@ -12,11 +12,10 @@ import (
|
||||
// tty. Prompt is printed on the writer out before attempting to read the
|
||||
// password. If the context is canceled, the function leaks the password reading
|
||||
// goroutine.
|
||||
func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string) (password string, err error) {
|
||||
fd := int(out.Fd())
|
||||
state, err := term.GetState(fd)
|
||||
func ReadPassword(ctx context.Context, inFd int, out io.Writer, prompt string) (password string, err error) {
|
||||
state, err := term.GetState(inFd)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "unable to get terminal state: %v\n", err)
|
||||
_, _ = fmt.Fprintf(out, "unable to get terminal state: %v\n", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -29,7 +28,7 @@ func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buf, err = term.ReadPassword(int(in.Fd()))
|
||||
buf, err = term.ReadPassword(inFd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -38,9 +37,9 @@ func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := term.Restore(fd, state)
|
||||
err := term.Restore(inFd, state)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "unable to restore terminal state: %v\n", err)
|
||||
_, _ = fmt.Fprintf(out, "unable to restore terminal state: %v\n", err)
|
||||
}
|
||||
return "", ctx.Err()
|
||||
case <-done:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package ui
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
var _ Terminal = &MockTerminal{}
|
||||
|
||||
@@ -33,6 +36,10 @@ func (m *MockTerminal) InputIsTerminal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MockTerminal) ReadPassword(_ context.Context, _ string) (string, error) {
|
||||
return "password", nil
|
||||
}
|
||||
|
||||
func (m *MockTerminal) OutputRaw() io.Writer {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package ui
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Terminal is used to write messages and display status lines which can be
|
||||
// updated. See termstatus.Terminal for a concrete implementation.
|
||||
@@ -15,6 +18,7 @@ type Terminal interface {
|
||||
CanUpdateStatus() bool
|
||||
InputRaw() io.ReadCloser
|
||||
InputIsTerminal() bool
|
||||
ReadPassword(ctx context.Context, prompt string) (string, error)
|
||||
// OutputRaw returns the output writer. Should only be used if there is no
|
||||
// other option. Must not be used in combination with Print, Error, SetStatus
|
||||
// or any other method that writes to the terminal.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package termstatus
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -18,6 +19,7 @@ var _ ui.Terminal = &Terminal{}
|
||||
// printed.
|
||||
type Terminal struct {
|
||||
rd io.ReadCloser
|
||||
inFd uintptr
|
||||
wr io.Writer
|
||||
fd uintptr
|
||||
errWriter io.Writer
|
||||
@@ -100,6 +102,7 @@ func New(rd io.ReadCloser, wr io.Writer, errWriter io.Writer, disableStatus bool
|
||||
|
||||
if d, ok := rd.(fder); ok {
|
||||
if terminal.InputIsTerminal(d.Fd()) {
|
||||
t.inFd = d.Fd()
|
||||
t.inputIsTerminal = true
|
||||
}
|
||||
}
|
||||
@@ -130,6 +133,26 @@ func (t *Terminal) InputRaw() io.ReadCloser {
|
||||
return t.rd
|
||||
}
|
||||
|
||||
func (t *Terminal) ReadPassword(ctx context.Context, prompt string) (string, error) {
|
||||
if t.InputIsTerminal() {
|
||||
return terminal.ReadPassword(ctx, int(t.inFd), t.errWriter, prompt)
|
||||
}
|
||||
if t.OutputIsTerminal() {
|
||||
t.Print("reading repository password from stdin")
|
||||
}
|
||||
return readPassword(t.rd)
|
||||
}
|
||||
|
||||
// readPassword reads the password from the given reader directly.
|
||||
func readPassword(in io.Reader) (password string, err error) {
|
||||
sc := bufio.NewScanner(in)
|
||||
sc.Scan()
|
||||
if sc.Err() != nil {
|
||||
return "", fmt.Errorf("readPassword: %w", sc.Err())
|
||||
}
|
||||
return sc.Text(), nil
|
||||
}
|
||||
|
||||
// CanUpdateStatus return whether the status output is updated in place.
|
||||
func (t *Terminal) CanUpdateStatus() bool {
|
||||
return t.canUpdateStatus
|
||||
|
||||
@@ -3,6 +3,7 @@ package termstatus
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
@@ -76,3 +77,13 @@ func TestSanitizeLines(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errorReader struct{ err error }
|
||||
|
||||
func (r *errorReader) Read([]byte) (int, error) { return 0, r.err }
|
||||
|
||||
func TestReadPassword(t *testing.T) {
|
||||
want := errors.New("foo")
|
||||
_, err := readPassword(&errorReader{want})
|
||||
rtest.Assert(t, errors.Is(err, want), "wrong error %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user