taildrop: add logic for resuming partial files (#9785)
We add the following API:

* type FileChecksums
* type Checksum
* func Manager.PartialFiles
* func Manager.HashPartialFile
* func ResumeReader

The Manager methods provide the ability to query for partial files
and retrieve a list of checksums for a given partial file.
The ResumeReader function is a helper that wraps an io.Reader
to discard content that is identical locally and remotely.
The FileChecksums type represents the checksums of a file
and is safe to JSON marshal and send over the wire.

Updates tailscale/corp#14772

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
Co-authored-by: Rhea Ghosh <rhea@tailscale.com>
parent 24f322bc43
commit b1867eb23f
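To see how the new API fits together end to end, here is a minimal sketch (not part of the commit), modeled on the resume-noop case in taildrop/resume_test.go below. The empty ClientID, the base name "foo", and the directory path are placeholder values; in a real transfer the checksum callback would query the receiving peer over the peer API rather than a local Manager.

package main

import (
	"bytes"
	"io"
	"log"

	"tailscale.com/taildrop"
)

func main() {
	// Placeholder receiving-side state; the Dir path is hypothetical.
	m := &taildrop.Manager{Logf: log.Printf, Dir: "/tmp/taildrop-demo"}

	content := []byte("hello, world")
	r := io.Reader(bytes.NewReader(content))

	// Skip whatever prefix the receiver already holds in its partial file,
	// comparing local blocks against the checksums it reports.
	offset, r, err := taildrop.ResumeReader(r, func(offset, length int64) (taildrop.FileChecksums, error) {
		return m.HashPartialFile("", "foo", offset, length)
	})
	if err != nil {
		log.Fatal(err)
	}

	// Send only the remaining content, starting at the agreed offset.
	if _, err := m.PutFile("", "foo", r, offset, -1); err != nil {
		log.Fatal(err)
	}
}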
@@ -3552,7 +3552,7 @@ func (b *LocalBackend) initPeerAPIListener() {
 		b:        b,
 		taildrop: &taildrop.Manager{
 			Logf:             b.logf,
-			Clock:            b.clock,
+			Clock:            tstime.DefaultClock{b.clock},
 			Dir:              fileRoot,
 			DirectFileMode:   b.directFileRoot != "",
 			AvoidFinalRename: !b.directFileDoFinalRename,
@@ -541,8 +541,7 @@ func(t *testing.T, env *peerAPITestEnv) {
 	rootDir = t.TempDir()
 	if e.ph.ps.taildrop == nil {
 		e.ph.ps.taildrop = &taildrop.Manager{
-			Logf:  e.logBuf.Logf,
-			Clock: &tstest.Clock{},
+			Logf: e.logBuf.Logf,
 		}
 	}
 	e.ph.ps.taildrop.Dir = rootDir
@@ -585,9 +584,8 @@ func TestFileDeleteRace(t *testing.T) {
 			clock: &tstest.Clock{},
 		},
 		taildrop: &taildrop.Manager{
-			Logf:  t.Logf,
-			Clock: &tstest.Clock{},
-			Dir:   dir,
+			Logf: t.Logf,
+			Dir:  dir,
 		},
 	}
 	ph := &peerAPIHandler{
taildrop/resume.go (new file, 211 lines)
@@ -0,0 +1,211 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package taildrop
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/hex"
+	"fmt"
+	"io"
+	"os"
+	"slices"
+	"strings"
+)
+
+var (
+	blockSize     = int64(64 << 10)
+	hashAlgorithm = "sha256"
+)
+
+// FileChecksums represents checksums into a partially received file.
+type FileChecksums struct {
+	// Offset is the offset into the file.
+	Offset int64 `json:"offset"`
+	// Length is the length of content being hashed in the file.
+	Length int64 `json:"length"`
+	// Checksums is a list of checksums of BlockSize-sized blocks
+	// starting from Offset. The number of checksums is the Length
+	// divided by BlockSize rounded up to the nearest integer.
+	// All blocks except for the last one are guaranteed to be checksums
+	// over BlockSize-sized blocks.
+	Checksums []Checksum `json:"checksums"`
+	// Algorithm is the hashing algorithm used to compute checksums.
+	Algorithm string `json:"algorithm"` // always "sha256" for now
+	// BlockSize is the size of each block.
+	// The last block may be smaller than this, but never zero.
+	BlockSize int64 `json:"blockSize"` // always (64<<10) for now
+}
+
+// Checksum is an opaque checksum that is comparable.
+type Checksum struct{ cs [sha256.Size]byte }
+
+func hash(b []byte) Checksum {
+	return Checksum{sha256.Sum256(b)}
+}
+func (cs Checksum) String() string {
+	return hex.EncodeToString(cs.cs[:])
+}
+func (cs Checksum) AppendText(b []byte) ([]byte, error) {
+	return hexAppendEncode(b, cs.cs[:]), nil
+}
+func (cs Checksum) MarshalText() ([]byte, error) {
+	return hexAppendEncode(nil, cs.cs[:]), nil
+}
+func (cs *Checksum) UnmarshalText(b []byte) error {
+	if len(b) != 2*len(cs.cs) {
+		return fmt.Errorf("invalid hex length: %d", len(b))
+	}
+	_, err := hex.Decode(cs.cs[:], b)
+	return err
+}
+
+// TODO(https://go.dev/issue/53693): Use hex.AppendEncode instead.
+func hexAppendEncode(dst, src []byte) []byte {
+	n := hex.EncodedLen(len(src))
+	dst = slices.Grow(dst, n)
+	hex.Encode(dst[len(dst):][:n], src)
+	return dst[:len(dst)+n]
+}
+
+// PartialFiles returns a list of partial files in [Manager.Dir]
+// that were sent (or are actively being sent) by the provided id.
+func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) {
+	if m.Dir == "" {
+		return ret, ErrNoTaildrop
+	}
+	if m.DirectFileMode && m.AvoidFinalRename {
+		return nil, nil // resuming is not supported for users that peek at our file structure
+	}
+
+	f, err := os.Open(m.Dir)
+	if err != nil {
+		return ret, err
+	}
+	defer f.Close()
+
+	suffix := id.partialSuffix()
+	for {
+		des, err := f.ReadDir(10)
+		if err != nil && err != io.EOF {
+			return ret, err
+		}
+		for _, de := range des {
+			if name := de.Name(); strings.HasSuffix(name, suffix) {
+				ret = append(ret, name)
+			}
+		}
+		if err == io.EOF {
+			return ret, nil
+		}
+	}
+}
+
+// HashPartialFile hashes the contents of a partial file sent by id,
+// starting at the specified offset and for the specified length.
+// If length is negative, it hashes the entire file.
+// If the length exceeds the remaining file length, then it hashes until EOF.
+// If [FileChecksums.Length] is less than length and no error occurred,
+// then it implies that all remaining content in the file has been hashed.
+func (m *Manager) HashPartialFile(id ClientID, baseName string, offset, length int64) (FileChecksums, error) {
+	if m.Dir == "" {
+		return FileChecksums{}, ErrNoTaildrop
+	}
+	if m.DirectFileMode && m.AvoidFinalRename {
+		return FileChecksums{}, nil // resuming is not supported for users that peek at our file structure
+	}
+
+	dstFile, err := m.joinDir(baseName)
+	if err != nil {
+		return FileChecksums{}, err
+	}
+	f, err := os.Open(dstFile + id.partialSuffix())
+	if err != nil {
+		if os.IsNotExist(err) {
+			return FileChecksums{}, nil
+		}
+		return FileChecksums{}, err
+	}
+	defer f.Close()
+
+	if _, err := f.Seek(offset, io.SeekStart); err != nil {
+		return FileChecksums{}, err
+	}
+	checksums := FileChecksums{
+		Offset:    offset,
+		Algorithm: hashAlgorithm,
+		BlockSize: blockSize,
+	}
+	b := make([]byte, blockSize) // TODO: Pool this?
+	r := io.LimitReader(f, length)
+	for {
+		switch n, err := io.ReadFull(r, b); {
+		case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF:
+			return checksums, err
+		case n == 0:
+			return checksums, nil
+		default:
+			checksums.Checksums = append(checksums.Checksums, hash(b[:n]))
+			checksums.Length += int64(n)
+		}
+	}
+}
+
+// ResumeReader reads and discards the leading content of r
+// that matches the content based on the checksums that exist.
+// It returns the number of bytes consumed,
+// and returns an [io.Reader] representing the remaining content.
+func ResumeReader(r io.Reader, hashFile func(offset, length int64) (FileChecksums, error)) (int64, io.Reader, error) {
+	if hashFile == nil {
+		return 0, r, nil
+	}
+
+	// Ask for checksums of a particular content length,
+	// where the amount of memory needed to represent the checksums themselves
+	// is exactly equal to the blockSize.
+	numBlocks := blockSize / sha256.Size
+	hashLength := blockSize * numBlocks
+
+	var offset int64
+	b := make([]byte, 0, blockSize)
+	for {
+		// Request a list of checksums for the partial file starting at offset.
+		checksums, err := hashFile(offset, hashLength)
+		if len(checksums.Checksums) == 0 || err != nil {
+			return offset, io.MultiReader(bytes.NewReader(b), r), err
+		} else if checksums.BlockSize != blockSize || checksums.Algorithm != hashAlgorithm {
+			return offset, io.MultiReader(bytes.NewReader(b), r), fmt.Errorf("invalid block size or hashing algorithm")
+		}
+
+		// Read from r, comparing each block with the provided checksums.
+		for _, want := range checksums.Checksums {
+			// Read a block from r.
+			n, err := io.ReadFull(r, b[:blockSize])
+			b = b[:n]
+			if err == io.EOF || err == io.ErrUnexpectedEOF {
+				err = nil
+			}
+			if len(b) == 0 || err != nil {
+				// This should not occur in practice.
+				// It implies that an error occurred reading r,
+				// or that the partial file on the remote side is fully complete.
+				return offset, io.MultiReader(bytes.NewReader(b), r), err
+			}
+
+			// Compare the local and remote block checksums.
+			// If it mismatches, then resume from this point.
+			got := hash(b)
+			if got != want {
+				return offset, io.MultiReader(bytes.NewReader(b), r), nil
+			}
+			offset += int64(len(b))
+			b = b[:0]
+		}
+
+		// We hashed the remainder of the partial file, so stop.
+		if checksums.Length < hashLength {
+			return offset, io.MultiReader(bytes.NewReader(b), r), nil
+		}
+	}
+}
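The commit message calls FileChecksums safe to JSON marshal and send over the wire: every field carries a json tag, and Checksum implements MarshalText/UnmarshalText, so each block hash round-trips as a plain hex string. A small illustration (not part of the commit; the empty checksum list just keeps the example self-contained, since Checksums is normally populated by Manager.HashPartialFile):

package main

import (
	"encoding/json"
	"fmt"

	"tailscale.com/taildrop"
)

func main() {
	fc := taildrop.FileChecksums{
		Offset:    0,
		Length:    0,
		Checksums: []taildrop.Checksum{},
		Algorithm: "sha256",
		BlockSize: 64 << 10,
	}
	b, err := json.Marshal(fc)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// Output:
	// {"offset":0,"length":0,"checksums":[],"algorithm":"sha256","blockSize":65536}
}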
taildrop/resume_test.go (new file, 63 lines)
@@ -0,0 +1,63 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package taildrop
+
+import (
+	"bytes"
+	"io"
+	"math/rand"
+	"os"
+	"testing"
+	"testing/iotest"
+
+	"tailscale.com/util/must"
+)
+
+func TestResume(t *testing.T) {
+	oldBlockSize := blockSize
+	defer func() { blockSize = oldBlockSize }()
+	blockSize = 256
+
+	m := Manager{Logf: t.Logf, Dir: t.TempDir()}
+
+	rn := rand.New(rand.NewSource(0))
+	want := make([]byte, 12345)
+	must.Get(io.ReadFull(rn, want))
+
+	t.Run("resume-noop", func(t *testing.T) {
+		r := io.Reader(bytes.NewReader(want))
+		offset, r, err := ResumeReader(r, func(offset, length int64) (FileChecksums, error) {
+			return m.HashPartialFile("", "foo", offset, length)
+		})
+		must.Do(err)
+		must.Get(m.PutFile("", "foo", r, offset, -1))
+		got := must.Get(os.ReadFile(must.Get(m.joinDir("foo"))))
+		if !bytes.Equal(got, want) {
+			t.Errorf("content mismatches")
+		}
+	})
+
+	t.Run("resume-retry", func(t *testing.T) {
+		rn := rand.New(rand.NewSource(0))
+		for {
+			r := io.Reader(bytes.NewReader(want))
+			offset, r, err := ResumeReader(r, func(offset, length int64) (FileChecksums, error) {
+				return m.HashPartialFile("", "foo", offset, length)
+			})
+			must.Do(err)
+			numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1)
+			if offset < int64(len(want)) {
+				r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe))
+			}
+			if _, err := m.PutFile("", "foo", r, offset, -1); err == nil {
+				break
+			}
+		}
+		got := must.Get(os.ReadFile(must.Get(m.joinDir("foo"))))
+		if !bytes.Equal(got, want) {
+			t.Errorf("content mismatches")
+		}
+	})
+
+}
@@ -167,9 +167,9 @@ func (m *Manager) DeleteFile(baseName string) error {
 	if m.DirectFileMode {
 		return errors.New("deletes not allowed in direct mode")
 	}
-	path, ok := m.joinDir(baseName)
-	if !ok {
-		return errors.New("bad filename")
+	path, err := m.joinDir(baseName)
+	if err != nil {
+		return err
 	}
 	var bo *backoff.Backoff
 	logf := m.Logf
@@ -224,9 +224,9 @@ func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err e
 	if m.DirectFileMode {
 		return nil, 0, errors.New("opens not allowed in direct mode")
 	}
-	path, ok := m.joinDir(baseName)
-	if !ok {
-		return nil, 0, errors.New("bad filename")
+	path, err := m.joinDir(baseName)
+	if err != nil {
+		return nil, 0, err
 	}
 	if fi, err := os.Stat(path + deletedSuffix); err == nil && fi.Mode().IsRegular() {
 		tryDeleteAgain(path)
@@ -22,7 +22,7 @@ type incomingFileKey struct {
 }
 
 type incomingFile struct {
-	clock tstime.Clock
+	clock tstime.DefaultClock
 
 	started time.Time
 	size    int64 // or -1 if unknown; never 0
@@ -62,6 +62,7 @@ func (f *incomingFile) Write(p []byte) (n int, err error) {
 // The baseName must be a base filename without any slashes.
 // The length is the expected length of content to read from r,
 // it may be negative to indicate that it is unknown.
+// It returns the length of the entire file.
 //
 // If there is a failure reading from r, then the partial file is not deleted
 // for some period of time. The [Manager.PartialFiles] and [Manager.HashPartialFile]
@@ -78,9 +79,9 @@ func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, len
 	case distro.Get() == distro.Unraid && !m.DirectFileMode:
 		return 0, ErrNotAccessible
 	}
-	dstPath, ok := m.joinDir(baseName)
-	if !ok {
-		return 0, ErrInvalidFileName
+	dstPath, err := m.joinDir(baseName)
+	if err != nil {
+		return 0, err
 	}
 
 	redactAndLogError := func(action string, err error) error {
@@ -45,7 +45,7 @@ func (id ClientID) partialSuffix() string {
 // Manager manages the state for receiving and managing taildropped files.
 type Manager struct {
 	Logf  logger.Logf
-	Clock tstime.Clock
+	Clock tstime.DefaultClock
 
 	// Dir is the directory to store received files.
 	// This may either be the final location for the files
@@ -131,15 +131,15 @@ func validFilenameRune(r rune) bool {
 	return unicode.IsPrint(r)
 }
 
-func (m *Manager) joinDir(baseName string) (fullPath string, ok bool) {
+func (m *Manager) joinDir(baseName string) (fullPath string, err error) {
 	if !utf8.ValidString(baseName) {
-		return "", false
+		return "", ErrInvalidFileName
 	}
 	if strings.TrimSpace(baseName) != baseName {
-		return "", false
+		return "", ErrInvalidFileName
 	}
 	if len(baseName) > 255 {
-		return "", false
+		return "", ErrInvalidFileName
 	}
 	// TODO: validate unicode normalization form too? Varies by platform.
 	clean := path.Clean(baseName)
@@ -147,17 +147,17 @@ func (m *Manager) joinDir(baseName string) (fullPath string, ok bool) {
 		clean == "." || clean == ".." ||
 		strings.HasSuffix(clean, deletedSuffix) ||
 		strings.HasSuffix(clean, partialSuffix) {
-		return "", false
+		return "", ErrInvalidFileName
 	}
 	for _, r := range baseName {
 		if !validFilenameRune(r) {
-			return "", false
+			return "", ErrInvalidFileName
 		}
 	}
 	if !filepath.IsLocal(baseName) {
-		return "", false
+		return "", ErrInvalidFileName
 	}
-	return filepath.Join(m.Dir, baseName), true
+	return filepath.Join(m.Dir, baseName), nil
 }
 
 // IncomingFiles returns a list of active incoming files.