mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-25 18:20:07 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			154 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			154 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| package taildrop
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/sha256"
 | |
| 	"encoding/hex"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"io/fs"
 | |
| 	"os"
 | |
| 	"strings"
 | |
| )
 | |
| 
 | |
// blockSize is the maximum number of bytes covered by one BlockChecksum (64 KiB).
var blockSize int64 = 64 * 1024

// hashAlgorithm is the algorithm name recorded in, and required of, every BlockChecksum.
var hashAlgorithm = "sha256"
 | |
| 
 | |
// BlockChecksum represents the checksum for a single block.
// HashPartialFile produces one per block of a partial file, and
// ResumeReader validates and compares them against incoming data.
type BlockChecksum struct {
	Checksum  Checksum `json:"checksum"` // digest of the block's bytes
	Algorithm string   `json:"algo"`     // always hashAlgorithm ("sha256") for now
	Size      int64    `json:"size"`     // bytes hashed; at most blockSize (64<<10), shorter for a final partial block
}
 | |
| 
 | |
| // Checksum is an opaque checksum that is comparable.
 | |
| type Checksum struct{ cs [sha256.Size]byte }
 | |
| 
 | |
| func hash(b []byte) Checksum {
 | |
| 	return Checksum{sha256.Sum256(b)}
 | |
| }
 | |
| func (cs Checksum) String() string {
 | |
| 	return hex.EncodeToString(cs.cs[:])
 | |
| }
 | |
| func (cs Checksum) AppendText(b []byte) ([]byte, error) {
 | |
| 	return hex.AppendEncode(b, cs.cs[:]), nil
 | |
| }
 | |
| func (cs Checksum) MarshalText() ([]byte, error) {
 | |
| 	return hex.AppendEncode(nil, cs.cs[:]), nil
 | |
| }
 | |
| func (cs *Checksum) UnmarshalText(b []byte) error {
 | |
| 	if len(b) != 2*len(cs.cs) {
 | |
| 		return fmt.Errorf("invalid hex length: %d", len(b))
 | |
| 	}
 | |
| 	_, err := hex.Decode(cs.cs[:], b)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // PartialFiles returns a list of partial files in [Handler.Dir]
 | |
| // that were sent (or is actively being sent) by the provided id.
 | |
| func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) {
 | |
| 	if m == nil || m.opts.Dir == "" {
 | |
| 		return nil, ErrNoTaildrop
 | |
| 	}
 | |
| 
 | |
| 	suffix := id.partialSuffix()
 | |
| 	if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool {
 | |
| 		if name := de.Name(); strings.HasSuffix(name, suffix) {
 | |
| 			ret = append(ret, name)
 | |
| 		}
 | |
| 		return true
 | |
| 	}); err != nil {
 | |
| 		return ret, redactError(err)
 | |
| 	}
 | |
| 	return ret, nil
 | |
| }
 | |
| 
 | |
| // HashPartialFile returns a function that hashes the next block in the file,
 | |
| // starting from the beginning of the file.
 | |
| // It returns (BlockChecksum{}, io.EOF) when the stream is complete.
 | |
| // It is the caller's responsibility to call close.
 | |
| func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (BlockChecksum, error), close func() error, err error) {
 | |
| 	if m == nil || m.opts.Dir == "" {
 | |
| 		return nil, nil, ErrNoTaildrop
 | |
| 	}
 | |
| 	noopNext := func() (BlockChecksum, error) { return BlockChecksum{}, io.EOF }
 | |
| 	noopClose := func() error { return nil }
 | |
| 
 | |
| 	dstFile, err := joinDir(m.opts.Dir, baseName)
 | |
| 	if err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 	f, err := os.Open(dstFile + id.partialSuffix())
 | |
| 	if err != nil {
 | |
| 		if os.IsNotExist(err) {
 | |
| 			return noopNext, noopClose, nil
 | |
| 		}
 | |
| 		return nil, nil, redactError(err)
 | |
| 	}
 | |
| 
 | |
| 	b := make([]byte, blockSize) // TODO: Pool this?
 | |
| 	next = func() (BlockChecksum, error) {
 | |
| 		switch n, err := io.ReadFull(f, b); {
 | |
| 		case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF:
 | |
| 			return BlockChecksum{}, redactError(err)
 | |
| 		case n == 0:
 | |
| 			return BlockChecksum{}, io.EOF
 | |
| 		default:
 | |
| 			return BlockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil
 | |
| 		}
 | |
| 	}
 | |
| 	close = f.Close
 | |
| 	return next, close, nil
 | |
| }
 | |
| 
 | |
| // ResumeReader reads and discards the leading content of r
 | |
| // that matches the content based on the checksums that exist.
 | |
| // It returns the number of bytes consumed,
 | |
| // and returns an [io.Reader] representing the remaining content.
 | |
| func ResumeReader(r io.Reader, hashNext func() (BlockChecksum, error)) (int64, io.Reader, error) {
 | |
| 	if hashNext == nil {
 | |
| 		return 0, r, nil
 | |
| 	}
 | |
| 
 | |
| 	var offset int64
 | |
| 	b := make([]byte, 0, blockSize)
 | |
| 	for {
 | |
| 		// Obtain the next block checksum from the remote peer.
 | |
| 		cs, err := hashNext()
 | |
| 		switch {
 | |
| 		case err == io.EOF:
 | |
| 			return offset, io.MultiReader(bytes.NewReader(b), r), nil
 | |
| 		case err != nil:
 | |
| 			return offset, io.MultiReader(bytes.NewReader(b), r), err
 | |
| 		case cs.Algorithm != hashAlgorithm || cs.Size < 0 || cs.Size > blockSize:
 | |
| 			return offset, io.MultiReader(bytes.NewReader(b), r), fmt.Errorf("invalid block size or hashing algorithm")
 | |
| 		}
 | |
| 
 | |
| 		// Read the contents of the next block.
 | |
| 		n, err := io.ReadFull(r, b[:cs.Size])
 | |
| 		b = b[:n]
 | |
| 		if err == io.EOF || err == io.ErrUnexpectedEOF {
 | |
| 			err = nil
 | |
| 		}
 | |
| 		if len(b) == 0 || err != nil {
 | |
| 			// This should not occur in practice.
 | |
| 			// It implies that an error occurred reading r,
 | |
| 			// or that the partial file on the remote side is fully complete.
 | |
| 			return offset, io.MultiReader(bytes.NewReader(b), r), err
 | |
| 		}
 | |
| 
 | |
| 		// Compare the local and remote block checksums.
 | |
| 		// If it mismatches, then resume from this point.
 | |
| 		if cs.Checksum != hash(b) {
 | |
| 			return offset, io.MultiReader(bytes.NewReader(b), r), nil
 | |
| 		}
 | |
| 		offset += int64(len(b))
 | |
| 		b = b[:0]
 | |
| 	}
 | |
| }
 | 
