From 0cf20bae6f5dffc1d05533ce90a14dacf6f2622a Mon Sep 17 00:00:00 2001 From: yqs112358 <37969157+yqs112358@users.noreply.github.com> Date: Thu, 26 Sep 2024 22:48:11 +0800 Subject: [PATCH 1/3] taildrop: support directory in cli Give support to directries in `tailscale file` command. A simple implement: - To transfer a directory, the dir was compressed into a TAR archive with `.tscompresseddir` file suffix and sent directly just as a single file. - When executing `tailscale file get` to receive from inbox, suffix of this file will be detected and the TAR will be extracted back to the directory Key changes: - Add `getCompressedDirReader` to compress dir into TAR - Add `extractCompressedDir ` to extract from TAR - Impl `if fi.IsDir()` to send a compressed directory - Detect file suffix in `receiveFile` to extract directory from TAR Signed-off-by: yqs112358 <37969157+yqs112358@users.noreply.github.com> --- cmd/tailscale/cli/file.go | 182 ++++++++++++++++++++++++++++++++++---- taildrop/taildrop.go | 4 + 2 files changed, 167 insertions(+), 19 deletions(-) diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index cd7762446..3e63e188f 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -4,6 +4,7 @@ package cli import ( + "archive/tar" "bytes" "context" "errors" @@ -19,6 +20,7 @@ import ( "path/filepath" "strings" "sync/atomic" + "tailscale.com/taildrop" "time" "unicode/utf8" @@ -129,6 +131,7 @@ func runCp(ctx context.Context, args []string) error { var name = cpArgs.name var contentLength int64 = -1 if fileArg == "-" { + // sending stdin as single file fileContents = &countingReader{Reader: os.Stdin} if name == "" { name, fileContents, err = pickStdinFilename() @@ -150,16 +153,28 @@ func runCp(ctx context.Context, args []string) error { return err } if fi.IsDir() { - return errors.New("directories not supported") - } - contentLength = fi.Size() - fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)} - if name == "" { - name = filepath.Base(fileArg) - } + // sending a directory + // compress it into a .tscompresseddir TAR archive to send + dirReader, err := getCompressedDirReader(fileArg) + if err != nil { + return err + } + fileContents = &countingReader{Reader: dirReader} + if name == "" { + name = filepath.Base(fileArg) + } + name += taildrop.CompressedDirSuffix + } else { + // sending a single file + contentLength = fi.Size() + fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)} + if name == "" { + name = filepath.Base(fileArg) + } - if envknob.Bool("TS_DEBUG_SLOW_PUSH") { - fileContents = &countingReader{Reader: &slowReader{r: fileContents}} + if envknob.Bool("TS_DEBUG_SLOW_PUSH") { + fileContents = &countingReader{Reader: &slowReader{r: fileContents}} + } } } @@ -187,7 +202,63 @@ func runCp(ctx context.Context, args []string) error { return nil } +// getCompressedDirReader will compress the given directory in TAR format +// returns an io.Reader to get the raw TAR stream +func getCompressedDirReader(dirPath string) (io.Reader, error) { + pr, pw := io.Pipe() + + go func() { + tarWriter := tar.NewWriter(pw) + defer func() { + tarWriter.Close() + pw.Close() + }() + + err := filepath.Walk(dirPath, func(path string, fileInfo os.FileInfo, err error) error { + if err != nil { + return err + } + + relativePath := filepath.Clean(strings.TrimPrefix(path, string(filepath.Separator))) + // uniform path splitter + relativePath = filepath.ToSlash(relativePath) + + header, err := tar.FileInfoHeader(fileInfo, relativePath) + if err != nil { + return err + } + header.Name = relativePath + + if err := tarWriter.WriteHeader(header); err != nil { + return err + } + + if !fileInfo.IsDir() { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + if _, err := io.Copy(tarWriter, file); err != nil { + return err + } + } + return nil + }) + + if err != nil { + pw.CloseWithError(err) + } + }() + + return pr, nil +} + func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64) { + // remove internal suffixes from name + name = strings.TrimSuffix(name, taildrop.CompressedDirSuffix) + var rateValueFast, rateValueSlow tsrate.Value rateValueFast.HalfLife = 1 * time.Second // fast response for rate measurement rateValueSlow.HalfLife = 10 * time.Second // slow response for ETA measurement @@ -486,20 +557,93 @@ func receiveFile(ctx context.Context, wf apitype.WaitingFile, dir string) (targe return "", 0, fmt.Errorf("opening inbox file %q: %w", wf.Name, err) } defer rc.Close() - f, err := openFileOrSubstitute(dir, wf.Name, getArgs.conflict) + if strings.HasSuffix(wf.Name, taildrop.CompressedDirSuffix) { + // receiving a directory + if err := extractCompressedDir(rc, dir); err != nil { + return "", 0, err + } + return wf.Name, size, nil + } else { + // receiving a single file + f, err := openFileOrSubstitute(dir, wf.Name, getArgs.conflict) + if err != nil { + return "", 0, err + } + // Apply quarantine attribute before copying + if err := quarantine.SetOnFile(f); err != nil { + return "", 0, fmt.Errorf("failed to apply quarantine attribute to file %v: %v", f.Name(), err) + } + _, err = io.Copy(f, rc) + if err != nil { + f.Close() + return "", 0, fmt.Errorf("failed to write %v: %v", f.Name(), err) + } + return f.Name(), size, f.Close() + } +} + +// extractCompressedDir will uncompress the given TAR archive +// to destination directory +func extractCompressedDir(rc io.ReadCloser, dstDir string) error { + r := tar.NewReader(rc) + + err := os.MkdirAll(dstDir, 0644) if err != nil { - return "", 0, err + return err } - // Apply quarantine attribute before copying - if err := quarantine.SetOnFile(f); err != nil { - return "", 0, fmt.Errorf("failed to apply quarantine attribute to file %v: %v", f.Name(), err) - } - _, err = io.Copy(f, rc) + + dstDir, err = filepath.Abs(dstDir) if err != nil { - f.Close() - return "", 0, fmt.Errorf("failed to write %v: %v", f.Name(), err) + return err } - return f.Name(), size, f.Close() + for { + header, err := r.Next() + if err == io.EOF { + break // extract finished + } + if err != nil { + return err + } + + fpath := filepath.Clean(filepath.Join(dstDir, header.Name)) + // prevent path traversal + if !strings.HasPrefix(fpath, dstDir) { + return errors.New("Bad filepath in TAR: " + fpath) + } + + switch header.Typeflag { + case tar.TypeDir: + // extract a dir + if err := os.MkdirAll(fpath, 0644); err != nil { + return err + } + case tar.TypeReg: + // extract a single file + dir := filepath.Dir(fpath) + fileName := filepath.Base(fpath) + if err := os.MkdirAll(dir, 0644); err != nil { + return err + } + + outFile, err := openFileOrSubstitute(dir, fileName, getArgs.conflict) + if err != nil { + return err + } + defer outFile.Close() + + // Apply quarantine attribute before copying + if err := quarantine.SetOnFile(outFile); err != nil { + return errors.New(fmt.Sprintf("failed to apply quarantine attribute to file %v: %v", fileName, err)) + } + if _, err := io.Copy(outFile, r); err != nil { + return err + } + default: + // unsupported type flag + continue + } + } + return nil } func runFileGetOneBatch(ctx context.Context, dir string) []error { diff --git a/taildrop/taildrop.go b/taildrop/taildrop.go index 9ad0e1a7e..1b4e3e627 100644 --- a/taildrop/taildrop.go +++ b/taildrop/taildrop.go @@ -51,6 +51,10 @@ const ( // permitted to be uploaded directly on any platform, like // partial files. deletedSuffix = ".deleted" + + // CompressedDirSuffix is used to mark a directory that is + // compressed and being transferred as an archive. + CompressedDirSuffix = ".tscompresseddir" ) // ClientID is an opaque identifier for file resumption. From ebe685431c2fb61783619630bdb28d494a832388 Mon Sep 17 00:00:00 2001 From: yqs112358 <37969157+yqs112358@users.noreply.github.com> Date: Sat, 28 Sep 2024 20:43:44 +0800 Subject: [PATCH 2/3] taildrop: extract directory in direct file mode Give support to directories in taildrop direct file mode If a directory has been delivered in direct file mode, the archive should be pre-extracted in `*taildrop.Manager.PutFile`, as no "tailscale file get" command will be executed manually. Incidentally, clients that use direct file mode to receive files (e.g. Android Client) will automatically get directory support in taildrop. Signed-off-by: yqs112358 <37969157+yqs112358@users.noreply.github.com> --- cmd/tailscale/cli/file.go | 122 +------------------- taildrop/send.go | 128 ++++++++++++--------- taildrop/utils.go | 229 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 306 insertions(+), 173 deletions(-) create mode 100644 taildrop/utils.go diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index 3e63e188f..2a673c174 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -4,7 +4,6 @@ package cli import ( - "archive/tar" "bytes" "context" "errors" @@ -155,7 +154,7 @@ func runCp(ctx context.Context, args []string) error { if fi.IsDir() { // sending a directory // compress it into a .tscompresseddir TAR archive to send - dirReader, err := getCompressedDirReader(fileArg) + dirReader, err := taildrop.GetCompressedDirReader(fileArg) if err != nil { return err } @@ -202,59 +201,6 @@ func runCp(ctx context.Context, args []string) error { return nil } -// getCompressedDirReader will compress the given directory in TAR format -// returns an io.Reader to get the raw TAR stream -func getCompressedDirReader(dirPath string) (io.Reader, error) { - pr, pw := io.Pipe() - - go func() { - tarWriter := tar.NewWriter(pw) - defer func() { - tarWriter.Close() - pw.Close() - }() - - err := filepath.Walk(dirPath, func(path string, fileInfo os.FileInfo, err error) error { - if err != nil { - return err - } - - relativePath := filepath.Clean(strings.TrimPrefix(path, string(filepath.Separator))) - // uniform path splitter - relativePath = filepath.ToSlash(relativePath) - - header, err := tar.FileInfoHeader(fileInfo, relativePath) - if err != nil { - return err - } - header.Name = relativePath - - if err := tarWriter.WriteHeader(header); err != nil { - return err - } - - if !fileInfo.IsDir() { - file, err := os.Open(path) - if err != nil { - return err - } - defer file.Close() - - if _, err := io.Copy(tarWriter, file); err != nil { - return err - } - } - return nil - }) - - if err != nil { - pw.CloseWithError(err) - } - }() - - return pr, nil -} - func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64) { // remove internal suffixes from name name = strings.TrimSuffix(name, taildrop.CompressedDirSuffix) @@ -559,7 +505,7 @@ func receiveFile(ctx context.Context, wf apitype.WaitingFile, dir string) (targe defer rc.Close() if strings.HasSuffix(wf.Name, taildrop.CompressedDirSuffix) { // receiving a directory - if err := extractCompressedDir(rc, dir); err != nil { + if err := taildrop.ExtractCompressedDir(rc, dir, getArgs.conflict.String()); err != nil { return "", 0, err } return wf.Name, size, nil @@ -582,70 +528,6 @@ func receiveFile(ctx context.Context, wf apitype.WaitingFile, dir string) (targe } } -// extractCompressedDir will uncompress the given TAR archive -// to destination directory -func extractCompressedDir(rc io.ReadCloser, dstDir string) error { - r := tar.NewReader(rc) - - err := os.MkdirAll(dstDir, 0644) - if err != nil { - return err - } - - dstDir, err = filepath.Abs(dstDir) - if err != nil { - return err - } - for { - header, err := r.Next() - if err == io.EOF { - break // extract finished - } - if err != nil { - return err - } - - fpath := filepath.Clean(filepath.Join(dstDir, header.Name)) - // prevent path traversal - if !strings.HasPrefix(fpath, dstDir) { - return errors.New("Bad filepath in TAR: " + fpath) - } - - switch header.Typeflag { - case tar.TypeDir: - // extract a dir - if err := os.MkdirAll(fpath, 0644); err != nil { - return err - } - case tar.TypeReg: - // extract a single file - dir := filepath.Dir(fpath) - fileName := filepath.Base(fpath) - if err := os.MkdirAll(dir, 0644); err != nil { - return err - } - - outFile, err := openFileOrSubstitute(dir, fileName, getArgs.conflict) - if err != nil { - return err - } - defer outFile.Close() - - // Apply quarantine attribute before copying - if err := quarantine.SetOnFile(outFile); err != nil { - return errors.New(fmt.Sprintf("failed to apply quarantine attribute to file %v: %v", fileName, err)) - } - if _, err := io.Copy(outFile, r); err != nil { - return err - } - default: - // unsupported type flag - continue - } - } - return nil -} - func runFileGetOneBatch(ctx context.Context, dir string) []error { var wfs []apitype.WaitingFile var err error diff --git a/taildrop/send.go b/taildrop/send.go index 0dff71b24..c5f74ad39 100644 --- a/taildrop/send.go +++ b/taildrop/send.go @@ -9,6 +9,7 @@ import ( "io" "os" "path/filepath" + "strings" "sync" "time" @@ -121,7 +122,7 @@ func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, len return 0, redactAndLogError("Create", err) } defer func() { - f.Close() // best-effort to cleanup dangling file handles + _ = f.Close() // best-effort to cleanup dangling file handles if err != nil { m.deleter.Insert(filepath.Base(partialPath)) // mark partial file for eventual deletion } @@ -170,68 +171,89 @@ func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, len } fileLength := offset + copyLength + // File has been successfully received inFile.mu.Lock() inFile.done = true inFile.mu.Unlock() - // File has been successfully received, rename the partial file - // to the final destination filename. If a file of that name already exists, - // then try multiple times with variations of the filename. - computePartialSum := sync.OnceValues(func() ([sha256.Size]byte, error) { - return sha256File(partialPath) - }) - maxRetries := 10 - for ; maxRetries > 0; maxRetries-- { - // Atomically rename the partial file as the destination file if it doesn't exist. - // Otherwise, it returns the length of the current destination file. - // The operation is atomic. - dstLength, err := func() (int64, error) { - m.renameMu.Lock() - defer m.renameMu.Unlock() - switch fi, err := os.Stat(dstPath); { - case os.IsNotExist(err): - return -1, os.Rename(partialPath, dstPath) - case err != nil: - return -1, err - default: - return fi.Size(), nil + // If a directory has been delivered in direct file mode, the archive should be extracted here + // as no "tailscale file get" will be executed manually + if m.opts.DirectFileMode && strings.HasSuffix(dstPath, CompressedDirSuffix) { + dirArchive, err := os.Open(partialPath) + if err != nil { + return 0, redactAndLogError("OpenCompressedDir", err) + } + defer func() { + dirArchive.Close() + // Delete partial archive file here, because deleter does not clean files in DirectFileRoot + if err := os.Remove(partialPath); err != nil { + redactAndLogError("DeleteCompressedDirArchive", err) } }() - if err != nil { - return 0, redactAndLogError("Rename", err) - } - if dstLength < 0 { - break // we successfully renamed; so stop - } - // Avoid the final rename if a destination file has the same contents. - // - // Note: this is best effort and copying files from iOS from the Media Library - // results in processing on the iOS side which means the size and shas of the - // same file can be different. - if dstLength == fileLength { - partialSum, err := computePartialSum() - if err != nil { - return 0, redactAndLogError("Rename", err) - } - dstSum, err := sha256File(dstPath) - if err != nil { - return 0, redactAndLogError("Rename", err) - } - if dstSum == partialSum { - if err := os.Remove(partialPath); err != nil { - return 0, redactAndLogError("Remove", err) + // Conflict strategy: Always rename "dir_conflicted" to "dir_conflicted (1)" here + if err := ExtractCompressedDir(dirArchive, m.opts.Dir, CreateNumberedFiles); err != nil { + return 0, redactAndLogError("ExtractCompressedDir", err) + } + } else { + // Normally rename the partial file to the final destination filename. + // If a file of that name already exists, then try multiple times with variations of the filename. + computePartialSum := sync.OnceValues(func() ([sha256.Size]byte, error) { + return sha256File(partialPath) + }) + maxRetries := 10 + for ; maxRetries > 0; maxRetries-- { + // Atomically rename the partial file as the destination file if it doesn't exist. + // Otherwise, it returns the length of the current destination file. + // The operation is atomic. + dstLength, err := func() (int64, error) { + m.renameMu.Lock() + defer m.renameMu.Unlock() + switch fi, err := os.Stat(dstPath); { + case os.IsNotExist(err): + return -1, os.Rename(partialPath, dstPath) + case err != nil: + return -1, err + default: + return fi.Size(), nil } - break // we successfully found a content match; so stop + }() + if err != nil { + return 0, redactAndLogError("Rename", err) + } + if dstLength < 0 { + break // we successfully renamed; so stop } - } - // Choose a new destination filename and try again. - dstPath = NextFilename(dstPath) - inFile.finalPath = dstPath - } - if maxRetries <= 0 { - return 0, errors.New("too many retries trying to rename partial file") + // Avoid the final rename if a destination file has the same contents. + // + // Note: this is best effort and copying files from iOS from the Media Library + // results in processing on the iOS side which means the size and shas of the + // same file can be different. + if dstLength == fileLength { + partialSum, err := computePartialSum() + if err != nil { + return 0, redactAndLogError("Rename", err) + } + dstSum, err := sha256File(dstPath) + if err != nil { + return 0, redactAndLogError("Rename", err) + } + if dstSum == partialSum { + if err := os.Remove(partialPath); err != nil { + return 0, redactAndLogError("Remove", err) + } + break // we successfully found a content match; so stop + } + } + + // Choose a new destination filename and try again. + dstPath = NextFilename(dstPath) + inFile.finalPath = dstPath + } + if maxRetries <= 0 { + return 0, errors.New("too many retries trying to rename partial file") + } } m.totalReceived.Add(1) m.opts.SendFileNotify() diff --git a/taildrop/utils.go b/taildrop/utils.go new file mode 100644 index 000000000..bd6136a87 --- /dev/null +++ b/taildrop/utils.go @@ -0,0 +1,229 @@ +package taildrop + +import ( + "archive/tar" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "tailscale.com/util/quarantine" +) + +// GetCompressedDirReader will compress the given directory in TAR format +// returns an io.Reader to get the raw TAR stream +func GetCompressedDirReader(dirPath string) (io.Reader, error) { + pr, pw := io.Pipe() + + go func() { + tarWriter := tar.NewWriter(pw) + defer func() { + _ = tarWriter.Close() + _ = pw.Close() + }() + + dirPath = filepath.Clean(dirPath) + dirName := filepath.Base(dirPath) + var err error + if dirName == "." || dirName == ".." { + // best effort to get the dir name + dirPath, err = filepath.Abs(dirPath) + if err != nil { + _ = pw.CloseWithError(err) + return + } + dirName = filepath.Base(dirPath) + } + err = filepath.Walk(dirPath, func(path string, fileInfo os.FileInfo, err error) error { + if err != nil { + return err + } + relativePath, err := filepath.Rel(dirPath, path) + if err != nil { + return err + } + pathInTar := filepath.ToSlash(filepath.Join(dirName, relativePath)) + + // try to resolve symbol link + symbolLinkTarget := "" + if fileInfo.Mode()&os.ModeSymlink != 0 { + symbolLinkTarget, err = os.Readlink(path) + if err != nil { + symbolLinkTarget = "" + } + } + + header, err := tar.FileInfoHeader(fileInfo, symbolLinkTarget) + if err != nil { + return err + } + header.Name = pathInTar + if err := tarWriter.WriteHeader(header); err != nil { + return err + } + + if !fileInfo.IsDir() { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + if _, err := io.Copy(tarWriter, file); err != nil { + return err + } + } + return nil + }) + if err != nil { + _ = pw.CloseWithError(err) + return + } + }() + + return pr, nil +} + +const ( + SkipOnExist string = "skip" + OverwriteExisting string = "overwrite" // Overwrite any existing file at the target location + CreateNumberedFiles string = "rename" // Create an alternately named file in the style of Chrome Downloads +) + +func ReplacePrefix(str string, prefix string, replaceTo string) string { + if strings.HasPrefix(str, prefix) && prefix != replaceTo { + return replaceTo + strings.TrimPrefix(str, prefix) + } else { + return str + } +} + +// ExtractCompressedDir will uncompress the given TAR archive +// to destination directory +func ExtractCompressedDir(rc io.ReadCloser, dstDir string, conflictAction string) error { + r := tar.NewReader(rc) + + dstDir, err := filepath.Abs(dstDir) + if err != nil { + return err + } + + // Conflict check is only needed to be done once for the top-level directory in the archive + // Get first record in archive here, find and solve conflict + header, err := r.Next() + if err != nil { + // including EOF, let the caller know that the archive is empty + return err + } + topLevelDirName := strings.Split(header.Name, "/")[0] + // prevent path traversal + topLevelDir := filepath.Clean(filepath.Join(dstDir, topLevelDirName)) + if !strings.HasPrefix(topLevelDir, dstDir) { + return errors.New("Bad filepath in TAR: " + topLevelDir) + } + goodTopLevelDirName, err := processDirConflict(dstDir, topLevelDirName, conflictAction) + if err != nil { + return err + } + + for { + // replace top-level dir part in path to avoid possible conflict + currentPathPart := ReplacePrefix(header.Name, topLevelDirName, goodTopLevelDirName) + + fpath := filepath.Clean(filepath.Join(dstDir, currentPathPart)) + // prevent path traversal + if !strings.HasPrefix(fpath, dstDir) { + return errors.New("Bad filepath in TAR: " + fpath) + } + + switch header.Typeflag { + case tar.TypeDir: + // extract a dir + if err := os.MkdirAll(fpath, 0644); err != nil { + return err + } + case tar.TypeReg: + // extract a single file + dir := filepath.Dir(fpath) + fileName := filepath.Base(fpath) + if err := os.MkdirAll(dir, 0644); err != nil { + return err + } + outFile, err := os.OpenFile(filepath.Join(dir, fileName), os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644) + if err != nil { + return err + } + defer outFile.Close() + + // Apply quarantine attribute before copying + if err := quarantine.SetOnFile(outFile); err != nil { + return errors.New(fmt.Sprintf("failed to apply quarantine attribute to file %v: %v", fileName, err)) + } + if _, err := io.Copy(outFile, r); err != nil { + return err + } + default: + // unsupported type flag, just skip it + } + + header, err = r.Next() + if err == io.EOF { + break // extract finished + } + if err != nil { + return err + } + } + return nil +} + +// processDirConflict will check and try to solve directory conflict according +// to the strategy conflictAction. Returns the dirName that is able to use, or error. +func processDirConflict(parentDir string, dirName string, conflictAction string) (string, error) { + dir := filepath.Join(parentDir, dirName) + isDirExisting := checkDirExisting(dir) + + switch conflictAction { + default: + // This should not happen. + return "", fmt.Errorf("bad conflictAction argument") + case SkipOnExist: + if isDirExisting { + return "", fmt.Errorf("refusing to overwrite directory: %v", dir) + } + return dirName, nil + case OverwriteExisting: + if isDirExisting { + if err := os.RemoveAll(dir); err != nil { + return "", fmt.Errorf("unable to remove target directory: %w", err) + } + } + return dirName, nil + case CreateNumberedFiles: + // It's possible the target directory or filesystem isn't writable by us, + // not just that the target file(s) already exists. For now, give up after + // a limited number of attempts. In future, maybe distinguish this case + // and follow in the style of https://tinyurl.com/chromium100 + if !isDirExisting { + return dirName, nil + } + maxAttempts := 100 + for i := 1; i < maxAttempts; i++ { + newDirName := numberedDirName(dirName, i) + if !checkDirExisting(filepath.Join(parentDir, newDirName)) { + return newDirName, nil + } + } + return "", fmt.Errorf("unable to find a name for writing %v", dir) + } +} + +func checkDirExisting(dir string) bool { + _, statErr := os.Stat(dir) + return statErr == nil +} + +func numberedDirName(dir string, i int) string { + return fmt.Sprintf("%s (%d)", dir, i) +} From 4fe99c6c1f0a5fa6d2f6b93b6bfb27894b7d5c39 Mon Sep 17 00:00:00 2001 From: yqs112358 <37969157+yqs112358@users.noreply.github.com> Date: Mon, 30 Sep 2024 09:15:58 +0800 Subject: [PATCH 3/3] taildrop: add test for directory utils Add test for directory utils Signed-off-by: yqs112358 <37969157+yqs112358@users.noreply.github.com> --- taildrop/utils_test.go | 131 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 taildrop/utils_test.go diff --git a/taildrop/utils_test.go b/taildrop/utils_test.go new file mode 100644 index 000000000..b54ea6899 --- /dev/null +++ b/taildrop/utils_test.go @@ -0,0 +1,131 @@ +package taildrop + +import ( + "archive/tar" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "io" + "os" + "path/filepath" + "slices" + "strings" + "tailscale.com/util/must" + "testing" +) + +func TestTarArchiveUtils(t *testing.T) { + equalErrorPrefix := func(err error, prefix string) { + t.Helper() + assert.Error(t, err) + assert.True(t, strings.HasPrefix(err.Error(), prefix)) + } + + var autoCloser []io.ReadCloser + defer func() { + for _, r := range autoCloser { + _ = r.Close() + } + }() + + readerFromFile := func(file string) io.ReadCloser { + t.Helper() + f, err := os.Open(file) + must.Do(err) + autoCloser = append(autoCloser, f) + return f + } + + writeToFile := func(reader io.Reader, file string) { + t.Helper() + outFile, err := os.Create(file) + must.Do(err) + defer outFile.Close() + + _, err = io.Copy(outFile, reader) + must.Do(err) + } + + checkDirectory := func(dir string, want ...string) { + t.Helper() + var got []string + for _, de := range must.Get(os.ReadDir(dir)) { + got = append(got, de.Name()) + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("directory mismatch (-got +want):\n%s", diff) + } + } + + checkTarArchive := func(tarFile string, want ...string) { + t.Helper() + r := tar.NewReader(readerFromFile(tarFile)) + + var got []string + for { + header, err := r.Next() + if err == io.EOF { + break // extract finished + } + must.Do(err) + got = append(got, header.Name) + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("TAR archive mismatch (-got +want):\n%s", diff) + } + } + + dir := t.TempDir() + must.Do(os.MkdirAll(filepath.Join(dir, "root-dir"), 0644)) + must.Do(os.WriteFile(filepath.Join(dir, "root-dir/foo.txt"), []byte("This is foo.txt"), 0644)) + must.Do(os.WriteFile(filepath.Join(dir, "root-dir/bar"), []byte("This is bar"), 0644)) + must.Do(os.WriteFile(filepath.Join(dir, "root-dir/其他文字.docx"), []byte(""), 0644)) + must.Do(os.MkdirAll(filepath.Join(dir, "root-dir/sub-dir"), 0644)) + must.Do(os.WriteFile(filepath.Join(dir, "root-dir/sub-dir/buzz.log"), []byte("hello world..."), 0644)) + + // Test Directory Compression + tarPath := filepath.Join(dir, "root-dir.tscompresseddir") + + reader, err := GetCompressedDirReader(filepath.Join(dir, "root-dir")) + must.Do(err) + writeToFile(reader, tarPath) + checkTarArchive(tarPath, "root-dir", "root-dir/foo.txt", "root-dir/bar", "root-dir/其他文字.docx", + "root-dir/sub-dir", "root-dir/sub-dir/buzz.log") + + reader, err = GetCompressedDirReader(filepath.Join(dir, "./foo/bar/../../root-dir/sub/..")) + must.Do(err) + writeToFile(reader, tarPath) + checkTarArchive(tarPath, "root-dir", "root-dir/foo.txt", "root-dir/bar", "root-dir/其他文字.docx", + "root-dir/sub-dir", "root-dir/sub-dir/buzz.log") + + // Test Archive Extraction + downloadDir := filepath.Join(dir, "test-download") + must.Do(os.MkdirAll(downloadDir, 0644)) + + // success first time + err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, SkipOnExist) + must.Do(err) + // fail second time, due to SkipOnExist + err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, SkipOnExist) + equalErrorPrefix(err, "refusing to overwrite directory:") + // success again, due to OverwriteExisting + err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, OverwriteExisting) + must.Do(err) + + checkDirectory(downloadDir, "root-dir") + checkDirectory(filepath.Join(downloadDir, "root-dir"), "foo.txt", "bar", "其他文字.docx", "sub-dir") + checkDirectory(filepath.Join(downloadDir, "root-dir/sub-dir"), "buzz.log") + + // success twice, due to CreateNumberedFiles + err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, CreateNumberedFiles) + must.Do(err) + err = ExtractCompressedDir(readerFromFile(tarPath), downloadDir, CreateNumberedFiles) + must.Do(err) + + checkDirectory(downloadDir, "root-dir", "root-dir (1)", "root-dir (2)") + checkDirectory(filepath.Join(downloadDir, "root-dir (2)"), "foo.txt", "bar", "其他文字.docx", "sub-dir") + checkDirectory(filepath.Join(downloadDir, "root-dir (2)/sub-dir"), "buzz.log") +}