mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-05 14:57:49 +00:00
cmd/tailscale/cli: add progress to tailscale file cp
Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
parent
d1a5757639
commit
e27f4f022e
@ -19,9 +19,12 @@
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/peterbourgon/ff/v3/ffcli"
|
||||
"golang.org/x/time/rate"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
@ -49,6 +52,17 @@
|
||||
},
|
||||
}
|
||||
|
||||
type countingReader struct {
|
||||
io.Reader
|
||||
n atomic.Uint64
|
||||
}
|
||||
|
||||
func (c *countingReader) Read(buf []byte) (int, error) {
|
||||
n, err := c.Reader.Read(buf)
|
||||
c.n.Add(uint64(n))
|
||||
return n, err
|
||||
}
|
||||
|
||||
var fileCpCmd = &ffcli.Command{
|
||||
Name: "cp",
|
||||
ShortUsage: "file cp <files...> <target>:",
|
||||
@ -116,11 +130,11 @@ func runCp(ctx context.Context, args []string) error {
|
||||
}
|
||||
|
||||
for _, fileArg := range files {
|
||||
var fileContents io.Reader
|
||||
var fileContents *countingReader
|
||||
var name = cpArgs.name
|
||||
var contentLength int64 = -1
|
||||
if fileArg == "-" {
|
||||
fileContents = os.Stdin
|
||||
fileContents = &countingReader{Reader: os.Stdin}
|
||||
if name == "" {
|
||||
name, fileContents, err = pickStdinFilename()
|
||||
if err != nil {
|
||||
@ -144,19 +158,29 @@ func runCp(ctx context.Context, args []string) error {
|
||||
return errors.New("directories not supported")
|
||||
}
|
||||
contentLength = fi.Size()
|
||||
fileContents = io.LimitReader(f, contentLength)
|
||||
fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)}
|
||||
if name == "" {
|
||||
name = filepath.Base(fileArg)
|
||||
}
|
||||
|
||||
if envknob.Bool("TS_DEBUG_SLOW_PUSH") {
|
||||
fileContents = &slowReader{r: fileContents}
|
||||
fileContents = &countingReader{Reader: &slowReader{r: fileContents}}
|
||||
}
|
||||
}
|
||||
|
||||
if cpArgs.verbose {
|
||||
log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID)
|
||||
}
|
||||
|
||||
var (
|
||||
done = make(chan struct{}, 1)
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
if isatty.IsTerminal(os.Stderr.Fd()) {
|
||||
go printProgress(&wg, done, fileContents, name, contentLength)
|
||||
wg.Add(1)
|
||||
}
|
||||
|
||||
err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -164,10 +188,61 @@ func runCp(ctx context.Context, args []string) error {
|
||||
if cpArgs.verbose {
|
||||
log.Printf("sent %q", name)
|
||||
}
|
||||
done <- struct{}{}
|
||||
wg.Wait()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const vtRestartLine = "\r\x1b[K"
|
||||
|
||||
func printProgress(wg *sync.WaitGroup, done <-chan struct{}, r *countingReader, name string, contentLength int64) {
|
||||
defer wg.Done()
|
||||
var lastBytesRead uint64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
fmt.Fprintln(os.Stderr)
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
n := r.n.Load()
|
||||
contentLengthStr := "???"
|
||||
if contentLength > 0 {
|
||||
contentLengthStr = fmt.Sprint(contentLength / 1024)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s%s\t\t%s", vtRestartLine, padTruncateString(name, 36), padTruncateString(fmt.Sprintf("%d/%s kb", n/1024, contentLengthStr), 16))
|
||||
if contentLength > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\t%.02f%%", float64(n)/float64(contentLength)*100)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\t-------%%")
|
||||
}
|
||||
if lastBytesRead > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\t%d kb/s", (n-lastBytesRead)/1024)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\t-------")
|
||||
}
|
||||
lastBytesRead = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func padTruncateString(str string, truncateAt int) string {
|
||||
if len(str) <= truncateAt {
|
||||
return str + strings.Repeat(" ", truncateAt-len(str))
|
||||
}
|
||||
|
||||
// Truncate the string, but respect unicode codepoint boundaries.
|
||||
// As of RFC3629 utf-8 codepoints can be at most 4 bytes wide.
|
||||
for i := 1; i <= 4 && i < len(str)-truncateAt; i++ {
|
||||
if utf8.ValidString(str[:truncateAt-i]) {
|
||||
return str[:truncateAt-i] + "…"
|
||||
}
|
||||
}
|
||||
return "" // Should be unreachable
|
||||
}
|
||||
|
||||
func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNodeID, isOffline bool, err error) {
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
@ -230,12 +305,12 @@ func ext(b []byte) string {
|
||||
// pickStdinFilename reads a bit of stdin to return a good filename
|
||||
// for its contents. The returned Reader is the concatenation of the
|
||||
// read and unread bits.
|
||||
func pickStdinFilename() (name string, r io.Reader, err error) {
|
||||
func pickStdinFilename() (name string, r *countingReader, err error) {
|
||||
sniff, err := io.ReadAll(io.LimitReader(os.Stdin, maxSniff))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return "stdin" + ext(sniff), io.MultiReader(bytes.NewReader(sniff), os.Stdin), nil
|
||||
return "stdin" + ext(sniff), &countingReader{Reader: io.MultiReader(bytes.NewReader(sniff), os.Stdin)}, nil
|
||||
}
|
||||
|
||||
type slowReader struct {
|
||||
|
Loading…
x
Reference in New Issue
Block a user