From 96af35555a0cc96efde39ca3dace4e306eba9d4f Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Thu, 18 Sep 2025 22:17:21 +0200 Subject: [PATCH] termstatus: add stdin and inject into backup command --- cmd/restic/cmd_backup.go | 20 +++++++++--------- cmd/restic/cmd_backup_test.go | 2 +- cmd/restic/cmd_init_integration_test.go | 14 ++++++++++--- cmd/restic/cmd_key_integration_test.go | 15 ++++++++------ cmd/restic/cmd_list_integration_test.go | 2 +- cmd/restic/global.go | 6 +++--- cmd/restic/integration_helpers_test.go | 2 +- cmd/restic/main.go | 2 +- cmd/restic/secondary_repo_test.go | 6 ------ internal/terminal/stdio.go | 6 ++---- internal/ui/mock.go | 8 ++++++++ internal/ui/terminal.go | 2 ++ internal/ui/termstatus/status.go | 27 +++++++++++++++++++++---- internal/ui/termstatus/status_test.go | 2 +- 14 files changed, 73 insertions(+), 41 deletions(-) diff --git a/cmd/restic/cmd_backup.go b/cmd/restic/cmd_backup.go index e62e0d1c2..8c070b660 100644 --- a/cmd/restic/cmd_backup.go +++ b/cmd/restic/cmd_backup.go @@ -182,7 +182,7 @@ func filterExisting(items []string, warnf func(msg string, args ...interface{})) // If filename is empty, readPatternsFromFile returns an empty slice. // If filename is a dash (-), readPatternsFromFile will read the lines from the // standard input. -func readLines(filename string) ([]string, error) { +func readLines(filename string, stdin io.ReadCloser) ([]string, error) { if filename == "" { return nil, nil } @@ -193,7 +193,7 @@ func readLines(filename string) ([]string, error) { ) if filename == "-" { - data, err = io.ReadAll(os.Stdin) + data, err = io.ReadAll(stdin) } else { data, err = textfile.Read(filename) } @@ -218,8 +218,8 @@ func readLines(filename string) ([]string, error) { // readFilenamesFromFileRaw reads a list of filenames from the given file, // or stdin if filename is "-". Each filename is terminated by a zero byte, // which is stripped off. -func readFilenamesFromFileRaw(filename string) (names []string, err error) { - f := os.Stdin +func readFilenamesFromFileRaw(filename string, stdin io.ReadCloser) (names []string, err error) { + var f io.ReadCloser = stdin if filename != "-" { if f, err = os.Open(filename); err != nil { return nil, err @@ -378,13 +378,13 @@ func collectRejectFuncs(opts BackupOptions, targets []string, fs fs.FS, warnf fu } // collectTargets returns a list of target files/dirs from several sources. -func collectTargets(opts BackupOptions, args []string, warnf func(msg string, args ...interface{})) (targets []string, err error) { +func collectTargets(opts BackupOptions, args []string, warnf func(msg string, args ...interface{}), stdin io.ReadCloser) (targets []string, err error) { if opts.Stdin || opts.StdinCommand { return nil, nil } for _, file := range opts.FilesFrom { - fromfile, err := readLines(file) + fromfile, err := readLines(file, stdin) if err != nil { return nil, err } @@ -409,7 +409,7 @@ func collectTargets(opts BackupOptions, args []string, warnf func(msg string, ar } for _, file := range opts.FilesFromVerbatim { - fromfile, err := readLines(file) + fromfile, err := readLines(file, stdin) if err != nil { return nil, err } @@ -422,7 +422,7 @@ func collectTargets(opts BackupOptions, args []string, warnf func(msg string, ar } for _, file := range opts.FilesFromRaw { - fromfile, err := readFilenamesFromFileRaw(file) + fromfile, err := readFilenamesFromFileRaw(file, stdin) if err != nil { return nil, err } @@ -490,7 +490,7 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter return err } - targets, err := collectTargets(opts, args, msg.E) + targets, err := collectTargets(opts, args, msg.E, term.InputRaw()) if err != nil { return err } @@ -582,7 +582,7 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter progressPrinter.V("read data from stdin") } filename := path.Join("/", opts.StdinFilename) - var source io.ReadCloser = os.Stdin + var source io.ReadCloser = term.InputRaw() if opts.StdinCommand { source, err = fs.NewCommandReader(ctx, args, msg.E) if err != nil { diff --git a/cmd/restic/cmd_backup_test.go b/cmd/restic/cmd_backup_test.go index ef5f02825..b607532b4 100644 --- a/cmd/restic/cmd_backup_test.go +++ b/cmd/restic/cmd_backup_test.go @@ -67,7 +67,7 @@ func TestCollectTargets(t *testing.T) { FilesFromRaw: []string{f3.Name()}, } - targets, err := collectTargets(opts, []string{filepath.Join(dir, "cmdline arg")}, t.Logf) + targets, err := collectTargets(opts, []string{filepath.Join(dir, "cmdline arg")}, t.Logf, nil) rtest.OK(t, err) sort.Strings(targets) rtest.Equals(t, expect, targets) diff --git a/cmd/restic/cmd_init_integration_test.go b/cmd/restic/cmd_init_integration_test.go index e5fba798a..878049ea6 100644 --- a/cmd/restic/cmd_init_integration_test.go +++ b/cmd/restic/cmd_init_integration_test.go @@ -18,7 +18,7 @@ func testRunInit(t testing.TB, opts GlobalOptions) { restic.TestSetLockTimeout(t, 0) err := withTermStatus(opts, func(ctx context.Context, gopts GlobalOptions) error { - return runInit(ctx, InitOptions{}, opts, nil, gopts.term) + return runInit(ctx, InitOptions{}, gopts, nil, gopts.term) }) rtest.OK(t, err) t.Logf("repository initialized at %v", opts.Repo) @@ -54,10 +54,18 @@ func TestInitCopyChunkerParams(t *testing.T) { }) rtest.OK(t, err) - repo, err := OpenRepository(context.TODO(), env.gopts, &progress.NoopPrinter{}) + var repo *repository.Repository + err = withTermStatus(env.gopts, func(ctx context.Context, gopts GlobalOptions) error { + repo, err = OpenRepository(ctx, gopts, &progress.NoopPrinter{}) + return err + }) rtest.OK(t, err) - otherRepo, err := OpenRepository(context.TODO(), env2.gopts, &progress.NoopPrinter{}) + var otherRepo *repository.Repository + err = withTermStatus(env2.gopts, func(ctx context.Context, gopts GlobalOptions) error { + otherRepo, err = OpenRepository(ctx, gopts, &progress.NoopPrinter{}) + return err + }) rtest.OK(t, err) rtest.Assert(t, repo.Config().ChunkerPolynomial == otherRepo.Config().ChunkerPolynomial, diff --git a/cmd/restic/cmd_key_integration_test.go b/cmd/restic/cmd_key_integration_test.go index 903fab07e..dad7f7e67 100644 --- a/cmd/restic/cmd_key_integration_test.go +++ b/cmd/restic/cmd_key_integration_test.go @@ -63,13 +63,16 @@ func testRunKeyAddNewKeyUserHost(t testing.TB, gopts GlobalOptions) { }) rtest.OK(t, err) - repo, err := OpenRepository(context.TODO(), gopts, &progress.NoopPrinter{}) - rtest.OK(t, err) - key, err := repository.SearchKey(context.TODO(), repo, testKeyNewPassword, 2, "") - rtest.OK(t, err) + _ = withTermStatus(gopts, func(ctx context.Context, gopts GlobalOptions) error { + repo, err := OpenRepository(ctx, gopts, &progress.NoopPrinter{}) + rtest.OK(t, err) + key, err := repository.SearchKey(ctx, repo, testKeyNewPassword, 2, "") + rtest.OK(t, err) - rtest.Equals(t, "john", key.Username) - rtest.Equals(t, "example.com", key.Hostname) + rtest.Equals(t, "john", key.Username) + rtest.Equals(t, "example.com", key.Hostname) + return nil + }) } func testRunKeyPasswd(t testing.TB, newPassword string, gopts GlobalOptions) { diff --git a/cmd/restic/cmd_list_integration_test.go b/cmd/restic/cmd_list_integration_test.go index 412bd3a2a..5b1409e9c 100644 --- a/cmd/restic/cmd_list_integration_test.go +++ b/cmd/restic/cmd_list_integration_test.go @@ -13,7 +13,7 @@ import ( func testRunList(t testing.TB, opts GlobalOptions, tpe string) restic.IDs { buf, err := withCaptureStdout(opts, func(opts GlobalOptions) error { return withTermStatus(opts, func(ctx context.Context, gopts GlobalOptions) error { - return runList(ctx, opts, []string{tpe}, gopts.term) + return runList(ctx, gopts, []string{tpe}, gopts.term) }) }) rtest.OK(t, err) diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 14dc8b6da..9872f0070 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -260,7 +260,7 @@ func ReadPassword(ctx context.Context, opts GlobalOptions, prompt string, printe err error ) - if terminal.StdinIsTerminal() { + if opts.term.InputIsTerminal() { password, err = terminal.ReadPassword(ctx, os.Stdin, os.Stderr, prompt) } else { printer.PT("reading repository password from stdin") @@ -286,7 +286,7 @@ func ReadPasswordTwice(ctx context.Context, gopts GlobalOptions, prompt1, prompt if err != nil { return "", err } - if terminal.StdinIsTerminal() { + if gopts.term.InputIsTerminal() { pw2, err := ReadPassword(ctx, gopts, prompt2, printer) if err != nil { return "", err @@ -349,7 +349,7 @@ func OpenRepository(ctx context.Context, opts GlobalOptions, printer progress.Pr } passwordTriesLeft := 1 - if terminal.StdinIsTerminal() && opts.password == "" && !opts.InsecureNoPassword { + if opts.term.InputIsTerminal() && opts.password == "" && !opts.InsecureNoPassword { passwordTriesLeft = 3 } diff --git a/cmd/restic/integration_helpers_test.go b/cmd/restic/integration_helpers_test.go index bfefe807f..dc8b3eda9 100644 --- a/cmd/restic/integration_helpers_test.go +++ b/cmd/restic/integration_helpers_test.go @@ -427,7 +427,7 @@ func withTermStatus(gopts GlobalOptions, callback func(ctx context.Context, gopt ctx, cancel := context.WithCancel(context.TODO()) var wg sync.WaitGroup - term := termstatus.New(gopts.stdout, gopts.stderr, gopts.Quiet) + term := termstatus.New(os.Stdin, gopts.stdout, gopts.stderr, gopts.Quiet) gopts.term = term wg.Add(1) go func() { diff --git a/cmd/restic/main.go b/cmd/restic/main.go index 91fbf638d..1c4fa5e9c 100644 --- a/cmd/restic/main.go +++ b/cmd/restic/main.go @@ -178,7 +178,7 @@ func main() { backends: collectBackends(), } func() { - term, cancel := termstatus.Setup(os.Stdout, os.Stderr, globalOptions.Quiet) + term, cancel := termstatus.Setup(os.Stdin, os.Stdout, os.Stderr, globalOptions.Quiet) defer cancel() globalOptions.stdout, globalOptions.stderr = termstatus.WrapStdio(term) globalOptions.term = term diff --git a/cmd/restic/secondary_repo_test.go b/cmd/restic/secondary_repo_test.go index 2c31bcecf..32206318a 100644 --- a/cmd/restic/secondary_repo_test.go +++ b/cmd/restic/secondary_repo_test.go @@ -131,12 +131,6 @@ func TestFillSecondaryGlobalOpts(t *testing.T) { PasswordCommand: "notEmpty", }, }, - { - // Test must fail as no password is given. - Opts: secondaryRepoOptions{ - Repo: "backupDst", - }, - }, { // Test must fail as current and legacy options are mixed Opts: secondaryRepoOptions{ diff --git a/internal/terminal/stdio.go b/internal/terminal/stdio.go index 1ee33e025..827b8426a 100644 --- a/internal/terminal/stdio.go +++ b/internal/terminal/stdio.go @@ -1,13 +1,11 @@ package terminal import ( - "os" - "golang.org/x/term" ) -func StdinIsTerminal() bool { - return term.IsTerminal(int(os.Stdin.Fd())) +func InputIsTerminal(fd uintptr) bool { + return term.IsTerminal(int(fd)) } func OutputIsTerminal(fd uintptr) bool { diff --git a/internal/ui/mock.go b/internal/ui/mock.go index fc5488792..70a95fe1b 100644 --- a/internal/ui/mock.go +++ b/internal/ui/mock.go @@ -25,6 +25,14 @@ func (m *MockTerminal) CanUpdateStatus() bool { return true } +func (m *MockTerminal) InputRaw() io.ReadCloser { + return nil +} + +func (m *MockTerminal) InputIsTerminal() bool { + return true +} + func (m *MockTerminal) OutputRaw() io.Writer { return nil } diff --git a/internal/ui/terminal.go b/internal/ui/terminal.go index 845e36508..8ff5d6f27 100644 --- a/internal/ui/terminal.go +++ b/internal/ui/terminal.go @@ -13,6 +13,8 @@ type Terminal interface { SetStatus(lines []string) // CanUpdateStatus returns true if the terminal can update the status lines. CanUpdateStatus() bool + InputRaw() io.ReadCloser + InputIsTerminal() bool // OutputRaw returns the output writer. Should only be used if there is no // other option. Must not be used in combination with Print, Error, SetStatus // or any other method that writes to the terminal. diff --git a/internal/ui/termstatus/status.go b/internal/ui/termstatus/status.go index a7bd60c31..be3a3ce59 100644 --- a/internal/ui/termstatus/status.go +++ b/internal/ui/termstatus/status.go @@ -17,14 +17,16 @@ var _ ui.Terminal = &Terminal{} // updated. When the output is redirected to a file, the status lines are not // printed. type Terminal struct { + rd io.ReadCloser wr io.Writer fd uintptr errWriter io.Writer msg chan message status chan status + lastStatusLen int + inputIsTerminal bool outputIsTerminal bool canUpdateStatus bool - lastStatusLen int // will be closed when the goroutine which runs Run() terminates, so it'll // yield a default value immediately @@ -56,12 +58,12 @@ type fder interface { // defer cancel() // // do stuff // ``` -func Setup(stdout, stderr io.Writer, quiet bool) (*Terminal, func()) { +func Setup(stdin io.ReadCloser, stdout, stderr io.Writer, quiet bool) (*Terminal, func()) { var wg sync.WaitGroup // only shutdown once cancel is called to ensure that no output is lost cancelCtx, cancel := context.WithCancel(context.Background()) - term := New(stdout, stderr, quiet) + term := New(stdin, stdout, stderr, quiet) wg.Add(1) go func() { defer wg.Done() @@ -82,8 +84,9 @@ func Setup(stdout, stderr io.Writer, quiet bool) (*Terminal, func()) { // normal output (via Print/Printf) are written to wr, error messages are // written to errWriter. If disableStatus is set to true, no status messages // are printed even if the terminal supports it. -func New(wr io.Writer, errWriter io.Writer, disableStatus bool) *Terminal { +func New(rd io.ReadCloser, wr io.Writer, errWriter io.Writer, disableStatus bool) *Terminal { t := &Terminal{ + rd: rd, wr: wr, errWriter: errWriter, msg: make(chan message), @@ -95,6 +98,12 @@ func New(wr io.Writer, errWriter io.Writer, disableStatus bool) *Terminal { return t } + if d, ok := rd.(fder); ok { + if terminal.InputIsTerminal(d.Fd()) { + t.inputIsTerminal = true + } + } + if d, ok := wr.(fder); ok { if terminal.CanUpdateStatus(d.Fd()) { // only use the fancy status code when we're running on a real terminal. @@ -111,6 +120,16 @@ func New(wr io.Writer, errWriter io.Writer, disableStatus bool) *Terminal { return t } +// InputIsTerminal returns whether the input is a terminal. +func (t *Terminal) InputIsTerminal() bool { + return t.inputIsTerminal +} + +// InputRaw returns the input reader. +func (t *Terminal) InputRaw() io.ReadCloser { + return t.rd +} + // CanUpdateStatus return whether the status output is updated in place. func (t *Terminal) CanUpdateStatus() bool { return t.canUpdateStatus diff --git a/internal/ui/termstatus/status_test.go b/internal/ui/termstatus/status_test.go index b12928931..064b02989 100644 --- a/internal/ui/termstatus/status_test.go +++ b/internal/ui/termstatus/status_test.go @@ -13,7 +13,7 @@ import ( func TestSetStatus(t *testing.T) { var buf bytes.Buffer - term := New(&buf, io.Discard, false) + term := New(nil, &buf, io.Discard, false) term.canUpdateStatus = true term.fd = ^uintptr(0)