Merge 4fe99c6c1f0a5fa6d2f6b93b6bfb27894b7d5c39 into b3455fa99a5e8d07133d5140017ec7c49f032a07

Author: YQ, 2025-03-24 19:28:38 -04:00 (committed by GitHub)
commit bc506ef363
5 changed files with 487 additions and 75 deletions


@@ -19,6 +19,7 @@ import (
"path/filepath"
"strings"
"sync/atomic"
"tailscale.com/taildrop"
"time"
"unicode/utf8"
@@ -130,6 +131,7 @@ func runCp(ctx context.Context, args []string) error {
var name = cpArgs.name
var contentLength int64 = -1
if fileArg == "-" {
// sending stdin as single file
fileContents = &countingReader{Reader: os.Stdin}
if name == "" {
name, fileContents, err = pickStdinFilename()
@@ -151,16 +153,28 @@ func runCp(ctx context.Context, args []string) error {
return err
}
if fi.IsDir() {
return errors.New("directories not supported")
}
contentLength = fi.Size()
fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)}
if name == "" {
name = filepath.Base(fileArg)
}
// sending a directory
// compress it into a <dir-name>.tscompresseddir TAR archive to send
dirReader, err := taildrop.GetCompressedDirReader(fileArg)
if err != nil {
return err
}
fileContents = &countingReader{Reader: dirReader}
if name == "" {
name = filepath.Base(fileArg)
}
name += taildrop.CompressedDirSuffix
} else {
// sending a single file
contentLength = fi.Size()
fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)}
if name == "" {
name = filepath.Base(fileArg)
}
if envknob.Bool("TS_DEBUG_SLOW_PUSH") {
fileContents = &countingReader{Reader: &slowReader{r: fileContents}}
if envknob.Bool("TS_DEBUG_SLOW_PUSH") {
fileContents = &countingReader{Reader: &slowReader{r: fileContents}}
}
}
}
@@ -189,6 +203,9 @@ func runCp(ctx context.Context, args []string) error {
}
func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64) {
// remove internal suffixes from name
name = strings.TrimSuffix(name, taildrop.CompressedDirSuffix)
var rateValueFast, rateValueSlow tsrate.Value
rateValueFast.HalfLife = 1 * time.Second // fast response for rate measurement
rateValueSlow.HalfLife = 10 * time.Second // slow response for ETA measurement
@@ -518,20 +535,29 @@ func receiveFile(ctx context.Context, wf apitype.WaitingFile, dir string) (targe
return "", 0, fmt.Errorf("opening inbox file %q: %w", wf.Name, err)
}
defer rc.Close()
f, err := openFileOrSubstitute(dir, wf.Name, getArgs.conflict)
if err != nil {
return "", 0, err
if strings.HasSuffix(wf.Name, taildrop.CompressedDirSuffix) {
// receiving a directory
if err := taildrop.ExtractCompressedDir(rc, dir, getArgs.conflict.String()); err != nil {
return "", 0, err
}
return wf.Name, size, nil
} else {
// receiving a single file
f, err := openFileOrSubstitute(dir, wf.Name, getArgs.conflict)
if err != nil {
return "", 0, err
}
// Apply quarantine attribute before copying
if err := quarantine.SetOnFile(f); err != nil {
return "", 0, fmt.Errorf("failed to apply quarantine attribute to file %v: %v", f.Name(), err)
}
_, err = io.Copy(f, rc)
if err != nil {
f.Close()
return "", 0, fmt.Errorf("failed to write %v: %v", f.Name(), err)
}
return f.Name(), size, f.Close()
}
// Apply quarantine attribute before copying
if err := quarantine.SetOnFile(f); err != nil {
return "", 0, fmt.Errorf("failed to apply quarantine attribute to file %v: %v", f.Name(), err)
}
_, err = io.Copy(f, rc)
if err != nil {
f.Close()
return "", 0, fmt.Errorf("failed to write %v: %v", f.Name(), err)
}
return f.Name(), size, f.Close()
}
func runFileGetOneBatch(ctx context.Context, dir string) []error {

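For orientation, here is a minimal standalone sketch (not part of the diff; the helper name openSendStream and its error handling are assumptions) of the sender-side decision the hunks above introduce: stat the argument, stream a directory through taildrop.GetCompressedDirReader and append taildrop.CompressedDirSuffix to the advertised name, and keep the size-limited reader path for a regular file.

package main

import (
	"fmt"
	"io"
	"os"
	"path/filepath"

	"tailscale.com/taildrop"
)

// openSendStream returns the reader to upload plus the name advertised to the
// peer. A directory is streamed as a TAR archive and tagged with the
// ".tscompresseddir" suffix so the receiver knows to extract it; its length is
// unknown up front, so -1 is returned.
func openSendStream(fileArg string) (r io.Reader, name string, length int64, err error) {
	fi, err := os.Stat(fileArg)
	if err != nil {
		return nil, "", 0, err
	}
	name = filepath.Base(fileArg)
	if fi.IsDir() {
		r, err = taildrop.GetCompressedDirReader(fileArg)
		if err != nil {
			return nil, "", 0, err
		}
		return r, name + taildrop.CompressedDirSuffix, -1, nil
	}
	f, err := os.Open(fileArg) // Close handling elided in this sketch.
	if err != nil {
		return nil, "", 0, err
	}
	return io.LimitReader(f, fi.Size()), name, fi.Size(), nil
}

func main() {
	r, name, n, err := openSendStream(os.Args[1])
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Printf("would send %q (%d bytes, reader=%T)\n", name, n, r)
}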

@@ -9,6 +9,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -121,7 +122,7 @@ func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, len
return 0, redactAndLogError("Create", err)
}
defer func() {
f.Close() // best-effort to cleanup dangling file handles
_ = f.Close() // best-effort to cleanup dangling file handles
if err != nil {
m.deleter.Insert(filepath.Base(partialPath)) // mark partial file for eventual deletion
}
@@ -170,68 +171,89 @@ func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, len
}
fileLength := offset + copyLength
// File has been successfully received
inFile.mu.Lock()
inFile.done = true
inFile.mu.Unlock()
// File has been successfully received, rename the partial file
// to the final destination filename. If a file of that name already exists,
// then try multiple times with variations of the filename.
computePartialSum := sync.OnceValues(func() ([sha256.Size]byte, error) {
return sha256File(partialPath)
})
maxRetries := 10
for ; maxRetries > 0; maxRetries-- {
// Atomically rename the partial file as the destination file if it doesn't exist.
// Otherwise, it returns the length of the current destination file.
// The operation is atomic.
dstLength, err := func() (int64, error) {
m.renameMu.Lock()
defer m.renameMu.Unlock()
switch fi, err := os.Stat(dstPath); {
case os.IsNotExist(err):
return -1, os.Rename(partialPath, dstPath)
case err != nil:
return -1, err
default:
return fi.Size(), nil
// If a directory has been delivered in direct file mode, extract the archive here,
// since no "tailscale file get" will be run manually
if m.opts.DirectFileMode && strings.HasSuffix(dstPath, CompressedDirSuffix) {
dirArchive, err := os.Open(partialPath)
if err != nil {
return 0, redactAndLogError("OpenCompressedDir", err)
}
defer func() {
dirArchive.Close()
// Delete partial archive file here, because deleter does not clean files in DirectFileRoot
if err := os.Remove(partialPath); err != nil {
redactAndLogError("DeleteCompressedDirArchive", err)
}
}()
if err != nil {
return 0, redactAndLogError("Rename", err)
}
if dstLength < 0 {
break // we successfully renamed; so stop
}
// Avoid the final rename if a destination file has the same contents.
//
// Note: this is best effort and copying files from iOS from the Media Library
// results in processing on the iOS side which means the size and shas of the
// same file can be different.
if dstLength == fileLength {
partialSum, err := computePartialSum()
if err != nil {
return 0, redactAndLogError("Rename", err)
}
dstSum, err := sha256File(dstPath)
if err != nil {
return 0, redactAndLogError("Rename", err)
}
if dstSum == partialSum {
if err := os.Remove(partialPath); err != nil {
return 0, redactAndLogError("Remove", err)
// Conflict strategy: Always rename "dir_conflicted" to "dir_conflicted (1)" here
if err := ExtractCompressedDir(dirArchive, m.opts.Dir, CreateNumberedFiles); err != nil {
return 0, redactAndLogError("ExtractCompressedDir", err)
}
} else {
// Normally rename the partial file to the final destination filename.
// If a file of that name already exists, then try multiple times with variations of the filename.
computePartialSum := sync.OnceValues(func() ([sha256.Size]byte, error) {
return sha256File(partialPath)
})
maxRetries := 10
for ; maxRetries > 0; maxRetries-- {
// Atomically rename the partial file as the destination file if it doesn't exist.
// Otherwise, it returns the length of the current destination file.
// The operation is atomic.
dstLength, err := func() (int64, error) {
m.renameMu.Lock()
defer m.renameMu.Unlock()
switch fi, err := os.Stat(dstPath); {
case os.IsNotExist(err):
return -1, os.Rename(partialPath, dstPath)
case err != nil:
return -1, err
default:
return fi.Size(), nil
}
break // we successfully found a content match; so stop
}()
if err != nil {
return 0, redactAndLogError("Rename", err)
}
if dstLength < 0 {
break // we successfully renamed; so stop
}
}
// Choose a new destination filename and try again.
dstPath = NextFilename(dstPath)
inFile.finalPath = dstPath
}
if maxRetries <= 0 {
return 0, errors.New("too many retries trying to rename partial file")
// Avoid the final rename if a destination file has the same contents.
//
// Note: this is best effort and copying files from iOS from the Media Library
// results in processing on the iOS side which means the size and shas of the
// same file can be different.
if dstLength == fileLength {
partialSum, err := computePartialSum()
if err != nil {
return 0, redactAndLogError("Rename", err)
}
dstSum, err := sha256File(dstPath)
if err != nil {
return 0, redactAndLogError("Rename", err)
}
if dstSum == partialSum {
if err := os.Remove(partialPath); err != nil {
return 0, redactAndLogError("Remove", err)
}
break // we successfully found a content match; so stop
}
}
// Choose a new destination filename and try again.
dstPath = NextFilename(dstPath)
inFile.finalPath = dstPath
}
if maxRetries <= 0 {
return 0, errors.New("too many retries trying to rename partial file")
}
}
m.totalReceived.Add(1)
m.opts.SendFileNotify()

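The comment block above explains why PutFile avoids the final rename when the destination already holds the same bytes. A hedged, standalone re-statement of that shortcut (the helper names and placeholder paths below are assumptions, not the manager's own code): compare lengths first, and only hash both files with SHA-256 when the lengths agree.

package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"os"
)

// sha256File hashes the file at path with SHA-256.
func sha256File(path string) ([sha256.Size]byte, error) {
	var sum [sha256.Size]byte
	f, err := os.Open(path)
	if err != nil {
		return sum, err
	}
	defer f.Close()
	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return sum, err
	}
	copy(sum[:], h.Sum(nil))
	return sum, nil
}

// sameContents reports whether partial and dst hold identical bytes,
// comparing sizes first and hashing only when the sizes match.
func sameContents(partial, dst string) (bool, error) {
	pi, err := os.Stat(partial)
	if err != nil {
		return false, err
	}
	di, err := os.Stat(dst)
	if err != nil {
		return false, err
	}
	if pi.Size() != di.Size() {
		return false, nil
	}
	ps, err := sha256File(partial)
	if err != nil {
		return false, err
	}
	ds, err := sha256File(dst)
	if err != nil {
		return false, err
	}
	return ps == ds, nil
}

func main() {
	// Placeholder paths for illustration only.
	ok, err := sameContents("file.partial", "file")
	fmt.Println(ok, err)
}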

@@ -51,6 +51,10 @@ const (
// permitted to be uploaded directly on any platform, like
// partial files.
deletedSuffix = ".deleted"
// CompressedDirSuffix is used to mark a directory that is
// compressed and being transferred as an archive.
CompressedDirSuffix = ".tscompresseddir"
)
// ClientID is an opaque identifier for file resumption.

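A small illustration (an assumption for clarity, not code from the diff) of the naming convention this constant establishes: the sender appends the suffix to the advertised name, and receivers such as progressPrinter and receiveFile detect and strip it.

package main

import (
	"fmt"
	"strings"

	"tailscale.com/taildrop"
)

func main() {
	// What the sender advertises for a directory named "photos".
	name := "photos" + taildrop.CompressedDirSuffix
	if strings.HasSuffix(name, taildrop.CompressedDirSuffix) {
		// The receiver treats this as a directory archive and hides the suffix.
		fmt.Println("directory archive:", strings.TrimSuffix(name, taildrop.CompressedDirSuffix))
	}
}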
taildrop/utils.go (new file, 229 lines)

@@ -0,0 +1,229 @@
package taildrop
import (
"archive/tar"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"tailscale.com/util/quarantine"
)
// GetCompressedDirReader compresses the given directory into TAR format and
// returns an io.Reader that yields the raw TAR stream
func GetCompressedDirReader(dirPath string) (io.Reader, error) {
pr, pw := io.Pipe()
go func() {
tarWriter := tar.NewWriter(pw)
defer func() {
_ = tarWriter.Close()
_ = pw.Close()
}()
dirPath = filepath.Clean(dirPath)
dirName := filepath.Base(dirPath)
var err error
if dirName == "." || dirName == ".." {
// best effort to get the dir name
dirPath, err = filepath.Abs(dirPath)
if err != nil {
_ = pw.CloseWithError(err)
return
}
dirName = filepath.Base(dirPath)
}
err = filepath.Walk(dirPath, func(path string, fileInfo os.FileInfo, err error) error {
if err != nil {
return err
}
relativePath, err := filepath.Rel(dirPath, path)
if err != nil {
return err
}
pathInTar := filepath.ToSlash(filepath.Join(dirName, relativePath))
// try to resolve the symlink target
symbolLinkTarget := ""
if fileInfo.Mode()&os.ModeSymlink != 0 {
symbolLinkTarget, err = os.Readlink(path)
if err != nil {
symbolLinkTarget = ""
}
}
header, err := tar.FileInfoHeader(fileInfo, symbolLinkTarget)
if err != nil {
return err
}
header.Name = pathInTar
if err := tarWriter.WriteHeader(header); err != nil {
return err
}
if !fileInfo.IsDir() {
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close()
if _, err := io.Copy(tarWriter, file); err != nil {
return err
}
}
return nil
})
if err != nil {
_ = pw.CloseWithError(err)
return
}
}()
return pr, nil
}
const (
SkipOnExist string = "skip"
OverwriteExisting string = "overwrite" // Overwrite any existing file at the target location
CreateNumberedFiles string = "rename" // Create an alternately named file in the style of Chrome Downloads
)
func ReplacePrefix(str string, prefix string, replaceTo string) string {
if strings.HasPrefix(str, prefix) && prefix != replaceTo {
return replaceTo + strings.TrimPrefix(str, prefix)
} else {
return str
}
}
// ExtractCompressedDir extracts the given TAR archive
// into the destination directory dstDir
func ExtractCompressedDir(rc io.ReadCloser, dstDir string, conflictAction string) error {
r := tar.NewReader(rc)
dstDir, err := filepath.Abs(dstDir)
if err != nil {
return err
}
// The conflict check only needs to happen once, for the top-level directory in the archive.
// Read the first record here to find and resolve any conflict.
header, err := r.Next()
if err != nil {
// including EOF, let the caller know that the archive is empty
return err
}
topLevelDirName := strings.Split(header.Name, "/")[0]
// prevent path traversal
topLevelDir := filepath.Clean(filepath.Join(dstDir, topLevelDirName))
if !strings.HasPrefix(topLevelDir, dstDir) {
return errors.New("Bad filepath in TAR: " + topLevelDir)
}
goodTopLevelDirName, err := processDirConflict(dstDir, topLevelDirName, conflictAction)
if err != nil {
return err
}
for {
// replace top-level dir part in path to avoid possible conflict
currentPathPart := ReplacePrefix(header.Name, topLevelDirName, goodTopLevelDirName)
fpath := filepath.Clean(filepath.Join(dstDir, currentPathPart))
// prevent path traversal
if !strings.HasPrefix(fpath, dstDir) {
return errors.New("Bad filepath in TAR: " + fpath)
}
switch header.Typeflag {
case tar.TypeDir:
// extract a dir
if err := os.MkdirAll(fpath, 0644); err != nil {
return err
}
case tar.TypeReg:
// extract a single file
dir := filepath.Dir(fpath)
fileName := filepath.Base(fpath)
if err := os.MkdirAll(dir, 0644); err != nil {
return err
}
outFile, err := os.OpenFile(filepath.Join(dir, fileName), os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
return err
}
defer outFile.Close()
// Apply quarantine attribute before copying
if err := quarantine.SetOnFile(outFile); err != nil {
return errors.New(fmt.Sprintf("failed to apply quarantine attribute to file %v: %v", fileName, err))
}
if _, err := io.Copy(outFile, r); err != nil {
return err
}
default:
// unsupported type flag, just skip it
}
header, err = r.Next()
if err == io.EOF {
break // extract finished
}
if err != nil {
return err
}
}
return nil
}
// processDirConflict checks for and tries to resolve a directory name conflict according
// to the strategy conflictAction. It returns a usable directory name, or an error.
func processDirConflict(parentDir string, dirName string, conflictAction string) (string, error) {
dir := filepath.Join(parentDir, dirName)
isDirExisting := checkDirExisting(dir)
switch conflictAction {
default:
// This should not happen.
return "", fmt.Errorf("bad conflictAction argument")
case SkipOnExist:
if isDirExisting {
return "", fmt.Errorf("refusing to overwrite directory: %v", dir)
}
return dirName, nil
case OverwriteExisting:
if isDirExisting {
if err := os.RemoveAll(dir); err != nil {
return "", fmt.Errorf("unable to remove target directory: %w", err)
}
}
return dirName, nil
case CreateNumberedFiles:
// It's possible the target directory or filesystem isn't writable by us,
// not just that the target file(s) already exists. For now, give up after
// a limited number of attempts. In future, maybe distinguish this case
// and follow in the style of https://tinyurl.com/chromium100
if !isDirExisting {
return dirName, nil
}
maxAttempts := 100
for i := 1; i < maxAttempts; i++ {
newDirName := numberedDirName(dirName, i)
if !checkDirExisting(filepath.Join(parentDir, newDirName)) {
return newDirName, nil
}
}
return "", fmt.Errorf("unable to find a name for writing %v", dir)
}
}
func checkDirExisting(dir string) bool {
_, statErr := os.Stat(dir)
return statErr == nil
}
func numberedDirName(dir string, i int) string {
return fmt.Sprintf("%s (%d)", dir, i)
}

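A short usage sketch for the two helpers above (standalone and hedged; the source and destination paths are placeholders): archive a directory, then extract it elsewhere, numbering the extracted directory if the name is already taken. ExtractCompressedDir takes an io.ReadCloser, so the pipe reader from GetCompressedDirReader is wrapped with io.NopCloser.

package main

import (
	"io"
	"log"

	"tailscale.com/taildrop"
)

func main() {
	// Compress a directory into a TAR stream (placeholder path).
	r, err := taildrop.GetCompressedDirReader("testdata/src-dir")
	if err != nil {
		log.Fatal(err)
	}
	// Extract the stream into a destination directory, renaming on conflict
	// in the style of Chrome downloads ("src-dir (1)", "src-dir (2)", ...).
	rc := io.NopCloser(r)
	if err := taildrop.ExtractCompressedDir(rc, "testdata/dst", taildrop.CreateNumberedFiles); err != nil {
		log.Fatal(err)
	}
}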
taildrop/utils_test.go (new file, 131 lines)

@@ -0,0 +1,131 @@
package taildrop
import (
"archive/tar"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"io"
"os"
"path/filepath"
"slices"
"strings"
"tailscale.com/util/must"
"testing"
)
func TestTarArchiveUtils(t *testing.T) {
equalErrorPrefix := func(err error, prefix string) {
t.Helper()
assert.Error(t, err)
assert.True(t, strings.HasPrefix(err.Error(), prefix))
}
var autoCloser []io.ReadCloser
defer func() {
for _, r := range autoCloser {
_ = r.Close()
}
}()
readerFromFile := func(file string) io.ReadCloser {
t.Helper()
f, err := os.Open(file)
must.Do(err)
autoCloser = append(autoCloser, f)
return f
}
writeToFile := func(reader io.Reader, file string) {
t.Helper()
outFile, err := os.Create(file)
must.Do(err)
defer outFile.Close()
_, err = io.Copy(outFile, reader)
must.Do(err)
}
checkDirectory := func(dir string, want ...string) {
t.Helper()
var got []string
for _, de := range must.Get(os.ReadDir(dir)) {
got = append(got, de.Name())
}
slices.Sort(got)
slices.Sort(want)
if diff := cmp.Diff(got, want); diff != "" {
t.Fatalf("directory mismatch (-got +want):\n%s", diff)
}
}
checkTarArchive := func(tarFile string, want ...string) {
t.Helper()
r := tar.NewReader(readerFromFile(tarFile))
var got []string
for {
header, err := r.Next()
if err == io.EOF {
break // extract finished
}
must.Do(err)
got = append(got, header.Name)
}
slices.Sort(got)
slices.Sort(want)
if diff := cmp.Diff(got, want); diff != "" {
t.Fatalf("TAR archive mismatch (-got +want):\n%s", diff)
}
}
dir := t.TempDir()
must.Do(os.MkdirAll(filepath.Join(dir, "root-dir"), 0644))
must.Do(os.WriteFile(filepath.Join(dir, "root-dir/foo.txt"), []byte("This is foo.txt"), 0644))
must.Do(os.WriteFile(filepath.Join(dir, "root-dir/bar"), []byte("This is bar"), 0644))
must.Do(os.WriteFile(filepath.Join(dir, "root-dir/其他文字.docx"), []byte(""), 0644))
must.Do(os.MkdirAll(filepath.Join(dir, "root-dir/sub-dir"), 0644))
must.Do(os.WriteFile(filepath.Join(dir, "root-dir/sub-dir/buzz.log"), []byte("hello world..."), 0644))
// Test Directory Compression
tarPath := filepath.Join(dir, "root-dir.tscompresseddir")
reader, err := GetCompressedDirReader(filepath.Join(dir, "root-dir"))
must.Do(err)
writeToFile(reader, tarPath)
checkTarArchive(tarPath, "root-dir", "root-dir/foo.txt", "root-dir/bar", "root-dir/其他文字.docx",
"root-dir/sub-dir", "root-dir/sub-dir/buzz.log")
reader, err = GetCompressedDirReader(filepath.Join(dir, "./foo/bar/../../root-dir/sub/.."))
must.Do(err)
writeToFile(reader, tarPath)
checkTarArchive(tarPath, "root-dir", "root-dir/foo.txt", "root-dir/bar", "root-dir/其他文字.docx",
"root-dir/sub-dir", "root-dir/sub-dir/buzz.log")
// Test Archive Extraction
downloadDir := filepath.Join(dir, "test-download")
must.Do(os.MkdirAll(downloadDir, 0644))
// success first time
err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, SkipOnExist)
must.Do(err)
// fail second time, due to SkipOnExist
err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, SkipOnExist)
equalErrorPrefix(err, "refusing to overwrite directory:")
// success again, due to OverwriteExisting
err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, OverwriteExisting)
must.Do(err)
checkDirectory(downloadDir, "root-dir")
checkDirectory(filepath.Join(downloadDir, "root-dir"), "foo.txt", "bar", "其他文字.docx", "sub-dir")
checkDirectory(filepath.Join(downloadDir, "root-dir/sub-dir"), "buzz.log")
// success twice, due to CreateNumberedFiles
err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, CreateNumberedFiles)
must.Do(err)
err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, CreateNumberedFiles)
must.Do(err)
checkDirectory(downloadDir, "root-dir", "root-dir (1)", "root-dir (2)")
checkDirectory(filepath.Join(downloadDir, "root-dir (2)"), "foo.txt", "bar", "其他文字.docx", "sub-dir")
checkDirectory(filepath.Join(downloadDir, "root-dir (2)/sub-dir"), "buzz.log")
}