mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-12 13:48:01 +00:00
cmd/tailscale/cli: add progress to tailscale file cp
Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
@@ -19,9 +19,12 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/mattn/go-isatty"
|
||||||
"github.com/peterbourgon/ff/v3/ffcli"
|
"github.com/peterbourgon/ff/v3/ffcli"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
"tailscale.com/client/tailscale/apitype"
|
"tailscale.com/client/tailscale/apitype"
|
||||||
@@ -49,6 +52,17 @@ var fileCmd = &ffcli.Command{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type countingReader struct {
|
||||||
|
io.Reader
|
||||||
|
n atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *countingReader) Read(buf []byte) (int, error) {
|
||||||
|
n, err := c.Reader.Read(buf)
|
||||||
|
c.n.Add(uint64(n))
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
var fileCpCmd = &ffcli.Command{
|
var fileCpCmd = &ffcli.Command{
|
||||||
Name: "cp",
|
Name: "cp",
|
||||||
ShortUsage: "file cp <files...> <target>:",
|
ShortUsage: "file cp <files...> <target>:",
|
||||||
@@ -116,11 +130,11 @@ func runCp(ctx context.Context, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, fileArg := range files {
|
for _, fileArg := range files {
|
||||||
var fileContents io.Reader
|
var fileContents *countingReader
|
||||||
var name = cpArgs.name
|
var name = cpArgs.name
|
||||||
var contentLength int64 = -1
|
var contentLength int64 = -1
|
||||||
if fileArg == "-" {
|
if fileArg == "-" {
|
||||||
fileContents = os.Stdin
|
fileContents = &countingReader{Reader: os.Stdin}
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name, fileContents, err = pickStdinFilename()
|
name, fileContents, err = pickStdinFilename()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -144,19 +158,29 @@ func runCp(ctx context.Context, args []string) error {
|
|||||||
return errors.New("directories not supported")
|
return errors.New("directories not supported")
|
||||||
}
|
}
|
||||||
contentLength = fi.Size()
|
contentLength = fi.Size()
|
||||||
fileContents = io.LimitReader(f, contentLength)
|
fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)}
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name = filepath.Base(fileArg)
|
name = filepath.Base(fileArg)
|
||||||
}
|
}
|
||||||
|
|
||||||
if envknob.Bool("TS_DEBUG_SLOW_PUSH") {
|
if envknob.Bool("TS_DEBUG_SLOW_PUSH") {
|
||||||
fileContents = &slowReader{r: fileContents}
|
fileContents = &countingReader{Reader: &slowReader{r: fileContents}}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cpArgs.verbose {
|
if cpArgs.verbose {
|
||||||
log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID)
|
log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
done = make(chan struct{}, 1)
|
||||||
|
wg sync.WaitGroup
|
||||||
|
)
|
||||||
|
if isatty.IsTerminal(os.Stderr.Fd()) {
|
||||||
|
go printProgress(&wg, done, fileContents, name, contentLength)
|
||||||
|
wg.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents)
|
err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -164,10 +188,61 @@ func runCp(ctx context.Context, args []string) error {
|
|||||||
if cpArgs.verbose {
|
if cpArgs.verbose {
|
||||||
log.Printf("sent %q", name)
|
log.Printf("sent %q", name)
|
||||||
}
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const vtRestartLine = "\r\x1b[K"
|
||||||
|
|
||||||
|
func printProgress(wg *sync.WaitGroup, done <-chan struct{}, r *countingReader, name string, contentLength int64) {
|
||||||
|
defer wg.Done()
|
||||||
|
var lastBytesRead uint64
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
return
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
n := r.n.Load()
|
||||||
|
contentLengthStr := "???"
|
||||||
|
if contentLength > 0 {
|
||||||
|
contentLengthStr = fmt.Sprint(contentLength / 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s%s\t\t%s", vtRestartLine, padTruncateString(name, 36), padTruncateString(fmt.Sprintf("%d/%s kb", n/1024, contentLengthStr), 16))
|
||||||
|
if contentLength > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\t%.02f%%", float64(n)/float64(contentLength)*100)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\t-------%%")
|
||||||
|
}
|
||||||
|
if lastBytesRead > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\t%d kb/s", (n-lastBytesRead)/1024)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\t-------")
|
||||||
|
}
|
||||||
|
lastBytesRead = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func padTruncateString(str string, truncateAt int) string {
|
||||||
|
if len(str) <= truncateAt {
|
||||||
|
return str + strings.Repeat(" ", truncateAt-len(str))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate the string, but respect unicode codepoint boundaries.
|
||||||
|
// As of RFC3629 utf-8 codepoints can be at most 4 bytes wide.
|
||||||
|
for i := 1; i <= 4 && i < len(str)-truncateAt; i++ {
|
||||||
|
if utf8.ValidString(str[:truncateAt-i]) {
|
||||||
|
return str[:truncateAt-i] + "…"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "" // Should be unreachable
|
||||||
|
}
|
||||||
|
|
||||||
func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNodeID, isOffline bool, err error) {
|
func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNodeID, isOffline bool, err error) {
|
||||||
ip, err := netip.ParseAddr(ipStr)
|
ip, err := netip.ParseAddr(ipStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -230,12 +305,12 @@ func ext(b []byte) string {
|
|||||||
// pickStdinFilename reads a bit of stdin to return a good filename
|
// pickStdinFilename reads a bit of stdin to return a good filename
|
||||||
// for its contents. The returned Reader is the concatenation of the
|
// for its contents. The returned Reader is the concatenation of the
|
||||||
// read and unread bits.
|
// read and unread bits.
|
||||||
func pickStdinFilename() (name string, r io.Reader, err error) {
|
func pickStdinFilename() (name string, r *countingReader, err error) {
|
||||||
sniff, err := io.ReadAll(io.LimitReader(os.Stdin, maxSniff))
|
sniff, err := io.ReadAll(io.LimitReader(os.Stdin, maxSniff))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
return "stdin" + ext(sniff), io.MultiReader(bytes.NewReader(sniff), os.Stdin), nil
|
return "stdin" + ext(sniff), &countingReader{Reader: io.MultiReader(bytes.NewReader(sniff), os.Stdin)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type slowReader struct {
|
type slowReader struct {
|
||||||
|
Reference in New Issue
Block a user