From 469b7cabadcf17674b41b5bd807bd6ad83d9687f Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Thu, 19 Oct 2023 11:04:33 -0700 Subject: [PATCH] cmd/tailscale: improve taildrop progress printer on Linux (#9878) The progress printer was buggy where it would not print correctly and some of the truncation logic was faulty. The progress printer now prints something like: go1.21.3.linux-amd64.tar.gz 21.53MiB 13.83MiB/s 33.88% ETA 00:00:03 where it shows * the number of bytes transferred so far * the rate of bytes transferred (using a 1-second half-life for an exponentially weighted average) * the progress made as a percentage * the estimated time (as calculated from the rate of bytes transferred) Other changes: * It now correctly prints the progress for very small files * It prints at a faster rate (4Hz instead of 1Hz) * It uses IEC units for byte quantities (to avoid ambiguities of "kb" being kilobits or kilobytes) Updates tailscale/corp#14772 Signed-off-by: Joe Tsai --- cmd/tailscale/cli/file.go | 119 +++++++++++++++++++++++-------------- cmd/tailscale/depaware.txt | 1 + 2 files changed, 74 insertions(+), 46 deletions(-) diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index e583a977f..1c8fc60c1 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -18,7 +18,6 @@ "path" "path/filepath" "strings" - "sync" "sync/atomic" "time" "unicode/utf8" @@ -29,8 +28,11 @@ "tailscale.com/client/tailscale/apitype" "tailscale.com/envknob" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/tailcfg" + tsrate "tailscale.com/tstime/rate" "tailscale.com/util/quarantine" + "tailscale.com/util/truncate" "tailscale.com/version" ) @@ -52,12 +54,12 @@ type countingReader struct { io.Reader - n atomic.Uint64 + n atomic.Int64 } func (c *countingReader) Read(buf []byte) (int, error) { n, err := c.Reader.Read(buf) - c.n.Add(uint64(n)) + c.n.Add(int64(n)) return n, err } @@ -170,75 +172,100 @@ func runCp(ctx context.Context, args []string) error { log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID) } - var ( - done = make(chan struct{}, 1) - wg sync.WaitGroup - ) + var group syncs.WaitGroup + ctxProgress, cancelProgress := context.WithCancel(ctx) + defer cancelProgress() if isatty.IsTerminal(os.Stderr.Fd()) { - go printProgress(&wg, done, fileContents, name, contentLength) - wg.Add(1) + group.Go(func() { progressPrinter(ctxProgress, name, fileContents.n.Load, contentLength) }) } err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents) + cancelProgress() + group.Wait() // wait for progress printer to stop before reporting the error if err != nil { return err } if cpArgs.verbose { log.Printf("sent %q", name) } - done <- struct{}{} - wg.Wait() } return nil } -const vtRestartLine = "\r\x1b[K" +func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64) { + var rateValueFast, rateValueSlow tsrate.Value + rateValueFast.HalfLife = 1 * time.Second // fast response for rate measurement + rateValueSlow.HalfLife = 10 * time.Second // slow response for ETA measurement + var prevContentCount int64 + print := func() { + currContentCount := contentCount() + rateValueFast.Add(float64(currContentCount - prevContentCount)) + rateValueSlow.Add(float64(currContentCount - prevContentCount)) + prevContentCount = currContentCount -func printProgress(wg *sync.WaitGroup, done <-chan struct{}, r *countingReader, name string, contentLength int64) { - defer wg.Done() - var lastBytesRead uint64 + const vtRestartLine = "\r\x1b[K" + fmt.Fprintf(os.Stderr, "%s%s %s %s", + vtRestartLine, + rightPad(name, 36), + leftPad(formatIEC(float64(currContentCount), "B"), len("1023.00MiB")), + leftPad(formatIEC(rateValueFast.Rate(), "B/s"), len("1023.00MiB/s"))) + if contentLength >= 0 { + currContentCount = min(currContentCount, contentLength) // cap at 100% + ratioRemain := float64(currContentCount) / float64(contentLength) + bytesRemain := float64(contentLength - currContentCount) + secsRemain := bytesRemain / rateValueSlow.Rate() + secs := int(min(max(0, secsRemain), 99*60*60+59+60+59)) + fmt.Fprintf(os.Stderr, " %s %s", + leftPad(fmt.Sprintf("%0.2f%%", 100.0*ratioRemain), len("100.00%")), + fmt.Sprintf("ETA %02d:%02d:%02d", secs/60/60, (secs/60)%60, secs%60)) + } + } + tc := time.NewTicker(250 * time.Millisecond) + defer tc.Stop() + print() for { select { - case <-done: + case <-ctx.Done(): + print() fmt.Fprintln(os.Stderr) return - case <-time.After(time.Second): - n := r.n.Load() - contentLengthStr := "???" - if contentLength > 0 { - contentLengthStr = fmt.Sprint(contentLength / 1024) - } - - fmt.Fprintf(os.Stderr, "%s%s\t\t%s", vtRestartLine, padTruncateString(name, 36), padTruncateString(fmt.Sprintf("%d/%s kb", n/1024, contentLengthStr), 16)) - if contentLength > 0 { - fmt.Fprintf(os.Stderr, "\t%.02f%%", float64(n)/float64(contentLength)*100) - } else { - fmt.Fprintf(os.Stderr, "\t-------%%") - } - if lastBytesRead > 0 { - fmt.Fprintf(os.Stderr, "\t%d kb/s", (n-lastBytesRead)/1024) - } else { - fmt.Fprintf(os.Stderr, "\t-------") - } - lastBytesRead = n + case <-tc.C: + print() } } } -func padTruncateString(str string, truncateAt int) string { - if len(str) <= truncateAt { - return str + strings.Repeat(" ", truncateAt-len(str)) - } +func leftPad(s string, n int) string { + s = truncateString(s, n) + return strings.Repeat(" ", max(n-len(s), 0)) + s +} - // Truncate the string, but respect unicode codepoint boundaries. - // As of RFC3629 utf-8 codepoints can be at most 4 bytes wide. - for i := 1; i <= 4 && i < len(str)-truncateAt; i++ { - if utf8.ValidString(str[:truncateAt-i]) { - return str[:truncateAt-i] + "…" - } +func rightPad(s string, n int) string { + s = truncateString(s, n) + return s + strings.Repeat(" ", max(n-len(s), 0)) +} + +func truncateString(s string, n int) string { + if len(s) <= n { + return s + } + return truncate.String(s, max(n-1, 0)) + "…" +} + +func formatIEC(n float64, unit string) string { + switch { + case n < 1<<10: + return fmt.Sprintf("%0.2f%s", n/(1<<0), unit) + case n < 1<<20: + return fmt.Sprintf("%0.2fKi%s", n/(1<<10), unit) + case n < 1<<30: + return fmt.Sprintf("%0.2fMi%s", n/(1<<20), unit) + case n < 1<<40: + return fmt.Sprintf("%0.2fGi%s", n/(1<<30), unit) + default: + return fmt.Sprintf("%0.2fTi%s", n/(1<<40), unit) } - return "" // Should be unreachable } func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNodeID, isOffline bool, err error) { diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 878eaeab9..5908685c6 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -158,6 +158,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/slicesx from tailscale.com/net/dnscache+ tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli + tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli tailscale.com/util/vizerror from tailscale.com/types/ipproto+ 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate