From a008d97105ec5846b643e82f65154d7f32f126c6 Mon Sep 17 00:00:00 2001
From: Tom Proctor
Date: Fri, 28 Nov 2025 10:30:30 +0000
Subject: [PATCH] concurrent writers

Move the disk cache write paths behind two per-OS helpers,
DiskCache.writeOutputFile and DiskCache.writeActionFile.

On non-Windows builds both helpers keep using writeAtomic (temp file in
the destination directory, then rename). On Windows they write the
destination file in place: the action file is truncated only after the
write, and the output file is hashed with sha256 while it is written
and removed again if the hash does not match outputID.

Also drop the size == 0 special case in DiskCache.Put, the debug
logging in the old writeAtomic, and the commented-out HTTP timeout
TODOs.

Change-Id: I0776d7afec7158829c6e350d6fdd210cd61a46c4
Signed-off-by: Tom Proctor
---
 cmd/cigocacher/cigocacher.go      |  8 ---
 cmd/cigocacher/disk.go            | 58 +++----------------
 cmd/cigocacher/disk_notwindows.go | 46 +++++++++++++++
 cmd/cigocacher/disk_windows.go    | 96 +++++++++++++++++++++++++++++++
 4 files changed, 149 insertions(+), 59 deletions(-)
 create mode 100644 cmd/cigocacher/disk_notwindows.go
 create mode 100644 cmd/cigocacher/disk_windows.go

diff --git a/cmd/cigocacher/cigocacher.go b/cmd/cigocacher/cigocacher.go
index 8d75dcbe2..c323224b0 100644
--- a/cmd/cigocacher/cigocacher.go
+++ b/cmd/cigocacher/cigocacher.go
@@ -171,10 +171,6 @@ func (c *cigocacher) get(ctx context.Context, actionID string) (outputID, diskPa
 
 	defer res.Body.Close()
 
-	// TODO(tomhjp): make sure we timeout if cigocached disappears, but for some
-	// reason, this seemed to tank network performance.
-	// ctx, cancel := context.WithTimeout(ctx, httpTimeout(res.ContentLength))
-	// defer cancel()
 	diskPath, err = c.disk.Put(ctx, actionID, outputID, res.ContentLength, res.Body)
 	if err != nil {
 		return "", "", fmt.Errorf("error filling disk cache from HTTP: %w", err)
@@ -213,10 +209,6 @@ func (c *cigocacher) put(ctx context.Context, actionID, outputID string, size in
 	}
 	httpErrCh := make(chan error)
 	go func() {
-		// TODO(tomhjp): make sure we timeout if cigocached disappears, but for some
-		// reason, this seemed to tank network performance.
-		// ctx, cancel := context.WithTimeout(ctx, httpTimeout(size))
-		// defer cancel()
 		t0HTTP := time.Now()
 		defer func() {
 			c.putHTTPNanos.Add(time.Since(t0HTTP).Nanoseconds())
diff --git a/cmd/cigocacher/disk.go b/cmd/cigocacher/disk.go
index 9537e8126..39b756ab6 100644
--- a/cmd/cigocacher/disk.go
+++ b/cmd/cigocacher/disk.go
@@ -4,7 +4,6 @@
 package main
 
 import (
-	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
@@ -13,7 +12,6 @@ import (
 	"log"
 	"os"
 	"path/filepath"
-	"runtime"
 	"time"
 )
 
@@ -119,21 +117,12 @@ func (dc *DiskCache) Put(ctx context.Context, actionID, outputID string, size in
 		return "", fmt.Errorf("failed to create output directory: %w", err)
 	}
 
-	// Special case empty files; they're both common and easier to do race-free.
-	if size == 0 {
-		zf, err := os.OpenFile(outputFile, os.O_CREATE|os.O_RDWR, 0644)
-		if err != nil {
-			return "", err
-		}
-		zf.Close()
-	} else {
-		wrote, err := writeAtomic(outputFile, body)
-		if err != nil {
-			return "", err
-		}
-		if wrote != size {
-			return "", fmt.Errorf("wrote %d bytes, expected %d", wrote, size)
-		}
+	wrote, err := dc.writeOutputFile(body, size, outputID)
+	if err != nil {
+		return "", err
+	}
+	if wrote != size {
+		return "", fmt.Errorf("wrote %d bytes, expected %d", wrote, size)
 	}
 
 	ij, err := json.Marshal(indexEntry{
@@ -145,41 +134,8 @@ func (dc *DiskCache) Put(ctx context.Context, actionID, outputID string, size in
 	if err != nil {
 		return "", err
 	}
-	if _, err := writeAtomic(actionFile, bytes.NewReader(ij)); err != nil {
+	if err := dc.writeActionFile(ij, actionID); err != nil {
 		return "", fmt.Errorf("atomic write failed: %w", err)
 	}
 	return outputFile, nil
 }
-
-func writeAtomic(dest string, r io.Reader) (int64, error) {
-	tf, err := os.CreateTemp(filepath.Dir(dest), filepath.Base(dest)+".*")
-	if err != nil {
-		return 0, err
-	}
-	size, err := io.Copy(tf, r)
-	if err != nil {
-		tf.Close()
-		os.Remove(tf.Name())
-		return 0, err
-	}
-	if err := tf.Close(); err != nil {
-		os.Remove(tf.Name())
-		return 0, err
-	}
-	if err := os.Rename(tf.Name(), dest); err != nil {
-		os.Remove(tf.Name())
-		if runtime.GOOS == "windows" {
-			if st, statErr := os.Stat(dest); statErr == nil && st.Size() == size {
-				log.Printf("DEBUG: WE DID THE WINTHING")
-				return size, nil
-			} else {
-				log.Printf("DEBUG: %v", statErr)
-				if st != nil {
-					log.Printf("DEBUG: size=%d, wanted %d", st.Size(), size)
-				}
-			}
-		}
-		return 0, err
-	}
-	return size, nil
-}
diff --git a/cmd/cigocacher/disk_notwindows.go b/cmd/cigocacher/disk_notwindows.go
new file mode 100644
index 000000000..a90fb5950
--- /dev/null
+++ b/cmd/cigocacher/disk_notwindows.go
@@ -0,0 +1,46 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package main
+
+import (
+	"bytes"
+	"io"
+	"os"
+	"path/filepath"
+)
+
+func (dc *DiskCache) writeActionFile(b []byte, actionID string) error {
+	dest := dc.ActionFilename(actionID)
+	_, err := writeAtomic(dest, bytes.NewReader(b))
+	return err
+}
+
+func (dc *DiskCache) writeOutputFile(r io.Reader, _ int64, outputID string) (int64, error) {
+	dest := dc.OutputFilename(outputID)
+	return writeAtomic(dest, r)
+}
+
+func writeAtomic(dest string, r io.Reader) (int64, error) {
+	tf, err := os.CreateTemp(filepath.Dir(dest), filepath.Base(dest)+".*")
+	if err != nil {
+		return 0, err
+	}
+	size, err := io.Copy(tf, r)
+	if err != nil {
+		tf.Close()
+		os.Remove(tf.Name())
+		return 0, err
+	}
+	if err := tf.Close(); err != nil {
+		os.Remove(tf.Name())
+		return 0, err
+	}
+	if err := os.Rename(tf.Name(), dest); err != nil {
+		os.Remove(tf.Name())
+		return 0, err
+	}
+	return size, nil
+}
diff --git a/cmd/cigocacher/disk_windows.go b/cmd/cigocacher/disk_windows.go
new file mode 100644
index 000000000..6601c997e
--- /dev/null
+++ b/cmd/cigocacher/disk_windows.go
@@ -0,0 +1,96 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+	"crypto/sha256"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+)
+
+func (dc *DiskCache) writeActionFile(b []byte, actionID string) (retErr error) {
+	dest := dc.ActionFilename(actionID)
+	f, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE, 0o666)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		cerr := f.Close()
+		if retErr != nil || cerr != nil {
+			retErr = errors.Join(retErr, cerr, os.Remove(dest))
+		}
+	}()
+
+	_, err = f.Write(b)
+	if err != nil {
+		return err
+	}
+
+	// Truncate the file only *after* writing it.
+	// (This should be a no-op, but truncate just in case of previous corruption.)
+	//
+	// This differs from os.WriteFile, which truncates to 0 *before* writing
+	// via os.O_TRUNC. Truncating only after writing ensures that a second write
+	// of the same content to the same file is idempotent, and does not - even
+	// temporarily! - undo the effect of the first write.
+	return f.Truncate(int64(len(b)))
+}
+
+// writeOutputFile writes content to be cached to disk. The outputID is the
+// sha256 hash of the content, and each file should only be written ~once,
+// assuming no sha256 hash collisions. It may be written multiple times if
+// concurrent processes are both populating the same output. The file is opened
+// with FILE_SHARE_READ|FILE_SHARE_WRITE, which means both processes can write
+// the same contents concurrently without conflict.
+//
+// It makes a best effort to clean up if anything goes wrong, but the file may
+// be left in an inconsistent state in the event of disk-related errors such as
+// another process taking file locks, or power loss etc.
+func (dc *DiskCache) writeOutputFile(r io.Reader, size int64, outputID string) (_ int64, retErr error) {
+	dest := dc.OutputFilename(outputID)
+	info, err := os.Stat(dest)
+	if err == nil && info.Size() == size {
+		// Already exists, check the hash.
+		if f, err := os.Open(dest); err == nil {
+			h := sha256.New()
+			io.Copy(h, f)
+			f.Close()
+			if fmt.Sprintf("%x", h.Sum(nil)) == outputID {
+				// Still drain the reader to ensure associated resources are released.
+				return io.Copy(io.Discard, r)
+			}
+		}
+	}
+
+	// Didn't successfully find the pre-existing file, write it.
+	mode := os.O_WRONLY | os.O_CREATE
+	if err == nil && info.Size() > size {
+		mode |= os.O_TRUNC // Should never happen, but self-heal.
+	}
+	f, err := os.OpenFile(dest, mode, 0644)
+	if err != nil {
+		return 0, fmt.Errorf("failed to open output file %q: %w", dest, err)
+	}
+	defer func() {
+		cerr := f.Close()
+		if retErr != nil || cerr != nil {
+			retErr = errors.Join(retErr, cerr, os.Remove(dest))
+		}
+	}()
+
+	// Copy file to f, but also into h to double-check hash.
+	h := sha256.New()
+	w := io.MultiWriter(f, h)
+	n, err := io.Copy(w, r)
+	if err != nil {
+		return 0, err
+	}
+	if fmt.Sprintf("%x", h.Sum(nil)) != outputID {
+		return 0, errors.New("file content changed underfoot")
+	}
+
+	return n, nil
+}