cmd/tailscale/cli: add progress to tailscale file cp

Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
Tom DNetto 2022-11-28 18:28:11 -08:00 committed by Tom
parent d1a5757639
commit e27f4f022e

View File

@ -19,9 +19,12 @@
"path"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"unicode/utf8"
"github.com/mattn/go-isatty"
"github.com/peterbourgon/ff/v3/ffcli"
"golang.org/x/time/rate"
"tailscale.com/client/tailscale/apitype"
@ -49,6 +52,17 @@
},
}
type countingReader struct {
io.Reader
n atomic.Uint64
}
func (c *countingReader) Read(buf []byte) (int, error) {
n, err := c.Reader.Read(buf)
c.n.Add(uint64(n))
return n, err
}
var fileCpCmd = &ffcli.Command{
Name: "cp",
ShortUsage: "file cp <files...> <target>:",
@ -116,11 +130,11 @@ func runCp(ctx context.Context, args []string) error {
}
for _, fileArg := range files {
var fileContents io.Reader
var fileContents *countingReader
var name = cpArgs.name
var contentLength int64 = -1
if fileArg == "-" {
fileContents = os.Stdin
fileContents = &countingReader{Reader: os.Stdin}
if name == "" {
name, fileContents, err = pickStdinFilename()
if err != nil {
@ -144,19 +158,29 @@ func runCp(ctx context.Context, args []string) error {
return errors.New("directories not supported")
}
contentLength = fi.Size()
fileContents = io.LimitReader(f, contentLength)
fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)}
if name == "" {
name = filepath.Base(fileArg)
}
if envknob.Bool("TS_DEBUG_SLOW_PUSH") {
fileContents = &slowReader{r: fileContents}
fileContents = &countingReader{Reader: &slowReader{r: fileContents}}
}
}
if cpArgs.verbose {
log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID)
}
var (
done = make(chan struct{}, 1)
wg sync.WaitGroup
)
if isatty.IsTerminal(os.Stderr.Fd()) {
go printProgress(&wg, done, fileContents, name, contentLength)
wg.Add(1)
}
err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents)
if err != nil {
return err
@ -164,10 +188,61 @@ func runCp(ctx context.Context, args []string) error {
if cpArgs.verbose {
log.Printf("sent %q", name)
}
done <- struct{}{}
wg.Wait()
}
return nil
}
const vtRestartLine = "\r\x1b[K"
func printProgress(wg *sync.WaitGroup, done <-chan struct{}, r *countingReader, name string, contentLength int64) {
defer wg.Done()
var lastBytesRead uint64
for {
select {
case <-done:
fmt.Fprintln(os.Stderr)
return
case <-time.After(time.Second):
n := r.n.Load()
contentLengthStr := "???"
if contentLength > 0 {
contentLengthStr = fmt.Sprint(contentLength / 1024)
}
fmt.Fprintf(os.Stderr, "%s%s\t\t%s", vtRestartLine, padTruncateString(name, 36), padTruncateString(fmt.Sprintf("%d/%s kb", n/1024, contentLengthStr), 16))
if contentLength > 0 {
fmt.Fprintf(os.Stderr, "\t%.02f%%", float64(n)/float64(contentLength)*100)
} else {
fmt.Fprintf(os.Stderr, "\t-------%%")
}
if lastBytesRead > 0 {
fmt.Fprintf(os.Stderr, "\t%d kb/s", (n-lastBytesRead)/1024)
} else {
fmt.Fprintf(os.Stderr, "\t-------")
}
lastBytesRead = n
}
}
}
func padTruncateString(str string, truncateAt int) string {
if len(str) <= truncateAt {
return str + strings.Repeat(" ", truncateAt-len(str))
}
// Truncate the string, but respect unicode codepoint boundaries.
// As of RFC3629 utf-8 codepoints can be at most 4 bytes wide.
for i := 1; i <= 4 && i < len(str)-truncateAt; i++ {
if utf8.ValidString(str[:truncateAt-i]) {
return str[:truncateAt-i] + "…"
}
}
return "" // Should be unreachable
}
func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNodeID, isOffline bool, err error) {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
@ -230,12 +305,12 @@ func ext(b []byte) string {
// pickStdinFilename reads a bit of stdin to return a good filename
// for its contents. The returned Reader is the concatenation of the
// read and unread bits.
func pickStdinFilename() (name string, r io.Reader, err error) {
func pickStdinFilename() (name string, r *countingReader, err error) {
sniff, err := io.ReadAll(io.LimitReader(os.Stdin, maxSniff))
if err != nil {
return "", nil, err
}
return "stdin" + ext(sniff), io.MultiReader(bytes.NewReader(sniff), os.Stdin), nil
return "stdin" + ext(sniff), &countingReader{Reader: io.MultiReader(bytes.NewReader(sniff), os.Stdin)}, nil
}
type slowReader struct {