diff --git a/hscontrol/app.go b/hscontrol/app.go index 9eb6b4cc..ce2fd1d8 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -137,7 +137,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, - nodeNotifier: notifier.NewNotifier(), + nodeNotifier: notifier.NewNotifier(cfg), mapSessions: make(map[types.NodeID]*mapSession), } diff --git a/hscontrol/notifier/metrics.go b/hscontrol/notifier/metrics.go index c461d379..1cc4df2b 100644 --- a/hscontrol/notifier/metrics.go +++ b/hscontrol/notifier/metrics.go @@ -18,7 +18,12 @@ var ( Namespace: prometheusNamespace, Name: "notifier_update_sent_total", Help: "total count of update sent on nodes channel", - }, []string{"status", "type"}) + }, []string{"status", "type", "trigger"}) + notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "notifier_update_received_total", + Help: "total count of updates received by notifier", + }, []string{"type", "trigger"}) notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: prometheusNamespace, Name: "notifier_open_channels_total", diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 4ad58723..74b6645e 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -3,7 +3,7 @@ package notifier import ( "context" "fmt" - "slices" + "sort" "strings" "sync" "time" @@ -11,19 +11,27 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" + "tailscale.com/util/set" ) type Notifier struct { l sync.RWMutex nodes map[types.NodeID]chan<- types.StateUpdate connected *xsync.MapOf[types.NodeID, bool] + b *batcher } -func NewNotifier() *Notifier { - return &Notifier{ +func NewNotifier(cfg *types.Config) *Notifier { + n := &Notifier{ nodes: make(map[types.NodeID]chan<- types.StateUpdate), connected: xsync.NewMapOf[types.NodeID, bool](), } + b := newBatcher(cfg.Tuning.BatchChangeDelay, n) + n.b = b + // TODO(kradalby): clean this up + go b.doWork() + return n } func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) { @@ -108,13 +116,8 @@ func (n *Notifier) NotifyWithIgnore( update types.StateUpdate, ignoreNodeIDs ...types.NodeID, ) { - for nodeID := range n.nodes { - if slices.Contains(ignoreNodeIDs, nodeID) { - continue - } - - n.NotifyByNodeID(ctx, update, nodeID) - } + notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() + n.b.addOrPassthrough(update) } func (n *Notifier) NotifyByNodeID( @@ -139,10 +142,10 @@ func (n *Notifier) NotifyByNodeID( log.Error(). Err(ctx.Err()). Uint64("node.id", nodeID.Uint64()). - Any("origin", ctx.Value("origin")). - Any("origin-hostname", ctx.Value("hostname")). + Any("origin", types.NotifyOriginKey.Value(ctx)). + Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)). Msgf("update not sent, context cancelled") - notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc() + notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() return case c <- update: @@ -151,11 +154,23 @@ func (n *Notifier) NotifyByNodeID( Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update successfully sent on chan") - notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc() + notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc() } } } +func (n *Notifier) sendAll(update types.StateUpdate) { + start := time.Now() + n.l.RLock() + defer n.l.RUnlock() + notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds()) + + for _, c := range n.nodes { + c <- update + notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc() + } +} + func (n *Notifier) String() string { n.l.RLock() defer n.l.RUnlock() @@ -177,3 +192,166 @@ func (n *Notifier) String() string { return b.String() } + +type batcher struct { + tick *time.Ticker + + mu sync.Mutex + + cancelCh chan struct{} + + changedNodeIDs set.Slice[types.NodeID] + nodesChanged bool + patches map[types.NodeID]tailcfg.PeerChange + patchesChanged bool + + n *Notifier +} + +func newBatcher(batchTime time.Duration, n *Notifier) *batcher { + return &batcher{ + tick: time.NewTicker(batchTime), + cancelCh: make(chan struct{}), + patches: make(map[types.NodeID]tailcfg.PeerChange), + n: n, + } + +} + +func (b *batcher) close() { + b.cancelCh <- struct{}{} +} + +// addOrPassthrough adds the update to the batcher, if it is not a +// type that is currently batched, it will be sent immediately. +func (b *batcher) addOrPassthrough(update types.StateUpdate) { + b.mu.Lock() + defer b.mu.Unlock() + + switch update.Type { + case types.StatePeerChanged: + b.changedNodeIDs.Add(update.ChangeNodes...) + b.nodesChanged = true + + case types.StatePeerChangedPatch: + for _, newPatch := range update.ChangePatches { + if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok { + overwritePatch(&curr, newPatch) + b.patches[types.NodeID(newPatch.NodeID)] = curr + } else { + b.patches[types.NodeID(newPatch.NodeID)] = *newPatch + } + } + b.patchesChanged = true + + default: + b.n.sendAll(update) + } +} + +// flush sends all the accumulated patches to all +// nodes in the notifier. +func (b *batcher) flush() { + b.mu.Lock() + defer b.mu.Unlock() + + if b.nodesChanged || b.patchesChanged { + var patches []*tailcfg.PeerChange + // If a node is getting a full update from a change + // node update, then the patch can be dropped. + for nodeID, patch := range b.patches { + if b.changedNodeIDs.Contains(nodeID) { + delete(b.patches, nodeID) + } else { + patches = append(patches, &patch) + } + } + + changedNodes := b.changedNodeIDs.Slice().AsSlice() + sort.Slice(changedNodes, func(i, j int) bool { + return changedNodes[i] < changedNodes[j] + }) + + if b.changedNodeIDs.Slice().Len() > 0 { + update := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: changedNodes, + } + + b.n.sendAll(update) + } + + if len(patches) > 0 { + patchUpdate := types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: patches, + } + + b.n.sendAll(patchUpdate) + } + + b.changedNodeIDs = set.Slice[types.NodeID]{} + b.nodesChanged = false + b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches)) + b.patchesChanged = false + } +} + +func (b *batcher) doWork() { + for { + select { + case <-b.cancelCh: + return + case <-b.tick.C: + b.flush() + } + } +} + +// overwritePatch takes the current patch and a newer patch +// and override any field that has changed +func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) { + if newPatch.DERPRegion != 0 { + currPatch.DERPRegion = newPatch.DERPRegion + } + + if newPatch.Cap != 0 { + currPatch.Cap = newPatch.Cap + } + + if newPatch.CapMap != nil { + currPatch.CapMap = newPatch.CapMap + } + + if newPatch.Endpoints != nil { + currPatch.Endpoints = newPatch.Endpoints + } + + if newPatch.Key != nil { + currPatch.Key = newPatch.Key + } + + if newPatch.KeySignature != nil { + currPatch.KeySignature = newPatch.KeySignature + } + + if newPatch.DiscoKey != nil { + currPatch.DiscoKey = newPatch.DiscoKey + } + + if newPatch.Online != nil { + currPatch.Online = newPatch.Online + } + + if newPatch.LastSeen != nil { + currPatch.LastSeen = newPatch.LastSeen + } + + if newPatch.KeyExpiry != nil { + currPatch.KeyExpiry = newPatch.KeyExpiry + } + + if newPatch.Capabilities != nil { + currPatch.Capabilities = newPatch.Capabilities + } +} diff --git a/hscontrol/notifier/notifier_test.go b/hscontrol/notifier/notifier_test.go new file mode 100644 index 00000000..4d61f134 --- /dev/null +++ b/hscontrol/notifier/notifier_test.go @@ -0,0 +1,249 @@ +package notifier + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "tailscale.com/tailcfg" +) + +func TestBatcher(t *testing.T) { + tests := []struct { + name string + updates []types.StateUpdate + want []types.StateUpdate + }{ + { + name: "full-passthrough", + updates: []types.StateUpdate{ + { + Type: types.StateFullUpdate, + }, + }, + want: []types.StateUpdate{ + { + Type: types.StateFullUpdate, + }, + }, + }, + { + name: "derp-passthrough", + updates: []types.StateUpdate{ + { + Type: types.StateDERPUpdated, + }, + }, + want: []types.StateUpdate{ + { + Type: types.StateDERPUpdated, + }, + }, + }, + { + name: "single-node-update", + updates: []types.StateUpdate{ + { + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{ + 2, + }, + }, + }, + want: []types.StateUpdate{ + { + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{ + 2, + }, + }, + }, + }, + { + name: "merge-node-update", + updates: []types.StateUpdate{ + { + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{ + 2, 4, + }, + }, + { + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{ + 2, 3, + }, + }, + }, + want: []types.StateUpdate{ + { + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{ + 2, 3, 4, + }, + }, + }, + }, + { + name: "single-patch-update", + updates: []types.StateUpdate{ + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 2, + DERPRegion: 5, + }, + }, + }, + }, + want: []types.StateUpdate{ + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 2, + DERPRegion: 5, + }, + }, + }, + }, + }, + { + name: "merge-patch-to-same-node-update", + updates: []types.StateUpdate{ + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 2, + DERPRegion: 5, + }, + }, + }, + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 2, + DERPRegion: 6, + }, + }, + }, + }, + want: []types.StateUpdate{ + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 2, + DERPRegion: 6, + }, + }, + }, + }, + }, + { + name: "merge-patch-to-multiple-node-update", + updates: []types.StateUpdate{ + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 3, + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("1.1.1.1:9090"), + }, + }, + }, + }, + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 3, + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("1.1.1.1:9090"), + netip.MustParseAddrPort("2.2.2.2:8080"), + }, + }, + }, + }, + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 4, + DERPRegion: 6, + }, + }, + }, + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 4, + Cap: tailcfg.CapabilityVersion(54), + }, + }, + }, + }, + want: []types.StateUpdate{ + { + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: 3, + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("1.1.1.1:9090"), + netip.MustParseAddrPort("2.2.2.2:8080"), + }, + }, + { + NodeID: 4, + DERPRegion: 6, + Cap: tailcfg.CapabilityVersion(54), + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NewNotifier(&types.Config{ + Tuning: types.Tuning{ + // We will call flush manually for the tests, + // so do not run the worker. + BatchChangeDelay: time.Hour, + }, + }) + + ch := make(chan types.StateUpdate, 30) + defer close(ch) + n.AddNode(1, ch) + defer n.RemoveNode(1) + + for _, u := range tt.updates { + n.NotifyAll(context.Background(), u) + } + + n.b.flush() + + var got []types.StateUpdate + for len(ch) > 0 { + out := <-ch + got = append(got, out) + } + + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("batcher() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index b903f122..e3137cc6 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -66,10 +66,16 @@ func (h *Headscale) newMapSession( ) *mapSession { warnf, infof, tracef, errf := logPollFunc(req, node) - // Use a buffered channel in case a node is not fully ready - // to receive a message to make sure we dont block the entire - // notifier. - updateChan := make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize) + var updateChan chan types.StateUpdate + if req.Stream { + // Use a buffered channel in case a node is not fully ready + // to receive a message to make sure we dont block the entire + // notifier. + updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize) + updateChan <- types.StateUpdate{ + Type: types.StateFullUpdate, + } + } return &mapSession{ h: h, @@ -218,33 +224,26 @@ func (m *mapSession) serve() { ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) defer cancel() - // TODO(kradalby): Make this available through a tuning envvar - wait := time.Second - - // Add a circuit breaker, if the loop is not interrupted - // inbetween listening for the channels, some updates - // might get stale and stucked in the "changed" map - // defined below. - blockBreaker := time.NewTicker(wait) - - // true means changed, false means removed - var changed map[types.NodeID]bool - var patches []*tailcfg.PeerChange - var derp bool - - // Set full to true to immediatly send a full mapresponse - full := true - prev := time.Now() - lastMessage := "" - // Loop through updates and continuously send them to the // client. for { - // If a full update has been requested or there are patches, then send it immediately - // otherwise wait for the "batching" of changes or patches - if full || patches != nil || (changed != nil && time.Since(prev) > wait) { + // consume channels with update, keep alives or "batch" blocking signals + select { + case <-m.cancelCh: + m.tracef("poll cancelled received") + return + case <-ctx.Done(): + m.tracef("poll context done") + return + + // Consume all updates sent to node + case update := <-m.ch: + m.tracef("received stream update: %s %s", update.Type.String(), update.Message) + mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc() + var data []byte var err error + var lastMessage string // Ensure the node object is updated, for example, there // might have been a hostinfo update in a sidechannel @@ -256,62 +255,43 @@ func (m *mapSession) serve() { return } - // If there are patches _and_ fully changed nodes, filter the - // patches and remove all patches that are present for the full - // changes updates. This allows us to send them as part of the - // PeerChange update, but only for nodes that are not fully changed. - // The fully changed nodes will be updated from the database and - // have all the updates needed. - // This means that the patches left are for nodes that has no - // updates that requires a full update. - // Patches are not suppose to be mixed in, but can be. - // - // From tailcfg docs: - // These are applied after Peers* above, but in practice the - // control server should only send these on their own, without - // - // Currently, there is no effort to merge patch updates, they - // are all sent, and the client will apply them in order. - // TODO(kradalby): Merge Patches for the same IDs to send less - // data and give the client less work. - if patches != nil && changed != nil { - var filteredPatches []*tailcfg.PeerChange - - for _, patch := range patches { - if _, ok := changed[types.NodeID(patch.NodeID)]; !ok { - filteredPatches = append(filteredPatches, patch) - } - } - - patches = filteredPatches - } - updateType := "full" - // When deciding what update to send, the following is considered, - // Full is a superset of all updates, when a full update is requested, - // send only that and move on, all other updates will be present in - // a full map response. - // - // If a map of changed nodes exists, prefer sending that as it will - // contain all the updates for the node, including patches, as it - // is fetched freshly from the database when building the response. - // - // If there is full changes registered, but we have patches for individual - // nodes, send them. - // - // Finally, if a DERP map is the only request, send that alone. - if full { + switch update.Type { + case types.StateFullUpdate: m.tracef("Sending Full MapResponse") data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) - } else if changed != nil { + case types.StatePeerChanged: + changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) + + for _, nodeID := range update.ChangeNodes { + changed[nodeID] = true + } + + lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) updateType = "change" - } else if patches != nil { + + case types.StatePeerChangedPatch: m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy) + data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) updateType = "patch" - } else if derp { + case types.StatePeerRemoved: + changed := make(map[types.NodeID]bool, len(update.Removed)) + + for _, nodeID := range update.Removed { + changed[nodeID] = false + } + m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + updateType = "remove" + case types.StateSelfUpdate: + lastMessage = update.Message + m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) + // create the map so an empty (self) update is sent + data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) + updateType = "remove" + case types.StateDERPUpdated: m.tracef("Sending DERPUpdate MapResponse") data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap) updateType = "derp" @@ -348,68 +328,6 @@ func (m *mapSession) serve() { m.tracef("update sent") } - // reset - changed = nil - patches = nil - lastMessage = "" - full = false - derp = false - prev = time.Now() - } - - // consume channels with update, keep alives or "batch" blocking signals - select { - case <-m.cancelCh: - m.tracef("poll cancelled received") - return - case <-ctx.Done(): - m.tracef("poll context done") - return - - // Avoid infinite block that would potentially leave - // some updates in the changed map. - case <-blockBreaker.C: - continue - - // Consume all updates sent to node - case update := <-m.ch: - m.tracef("received stream update: %s %s", update.Type.String(), update.Message) - mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc() - - switch update.Type { - case types.StateFullUpdate: - full = true - case types.StatePeerChanged: - if changed == nil { - changed = make(map[types.NodeID]bool) - } - - for _, nodeID := range update.ChangeNodes { - changed[nodeID] = true - } - - lastMessage = update.Message - case types.StatePeerChangedPatch: - patches = append(patches, update.ChangePatches...) - case types.StatePeerRemoved: - if changed == nil { - changed = make(map[types.NodeID]bool) - } - - for _, nodeID := range update.Removed { - changed[nodeID] = false - } - case types.StateSelfUpdate: - // create the map so an empty (self) update is sent - if changed == nil { - changed = make(map[types.NodeID]bool) - } - - lastMessage = update.Message - case types.StateDERPUpdated: - derp = true - } - case <-m.keepAliveTicker.C: data, err := m.mapper.KeepAliveResponse(m.req, m.node) if err != nil { diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 6d63f301..35f5e5e4 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -10,6 +10,7 @@ import ( "time" "tailscale.com/tailcfg" + "tailscale.com/util/ctxkey" ) const ( @@ -183,10 +184,14 @@ func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { } } +var ( + NotifyOriginKey = ctxkey.New("notify.origin", "") + NotifyHostnameKey = ctxkey.New("notify.hostname", "") +) + func NotifyCtx(ctx context.Context, origin, hostname string) context.Context { - ctx2, _ := context.WithTimeout( - context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin), - 3*time.Second, - ) + ctx2, _ := context.WithTimeout(ctx, 3*time.Second) + ctx2 = NotifyOriginKey.WithValue(ctx2, origin) + ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname) return ctx2 } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 0483213b..a118b6fc 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -11,6 +11,7 @@ import ( "encoding/pem" "errors" "fmt" + "io" "log" "math/big" "net" @@ -396,6 +397,14 @@ func (t *HeadscaleInContainer) Shutdown() error { ) } + err = t.SaveMetrics("/tmp/control/metrics.txt") + if err != nil { + log.Printf( + "Failed to metrics from control: %s", + err, + ) + } + // Send a interrupt signal to the "headscale" process inside the container // allowing it to shut down gracefully and flush the profile to disk. // The container will live for a bit longer due to the sleep at the end. @@ -448,6 +457,25 @@ func (t *HeadscaleInContainer) SaveLog(path string) error { return dockertestutil.SaveLog(t.pool, t.container, path) } +func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { + resp, err := http.Get(fmt.Sprintf("http://%s:9090/metrics", t.hostname)) + if err != nil { + return fmt.Errorf("getting metrics: %w", err) + } + defer resp.Body.Close() + out, err := os.Create(savePath) + if err != nil { + return fmt.Errorf("creating file for metrics: %w", err) + } + defer out.Close() + _, err = io.Copy(out, resp.Body) + if err != nil { + return fmt.Errorf("copy response to file: %w", err) + } + + return nil +} + func (t *HeadscaleInContainer) SaveProfile(savePath string) error { tarFile, err := t.FetchPath("/tmp/profile") if err != nil { diff --git a/integration/route_test.go b/integration/route_test.go index 15ea22b1..48b6c07f 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -252,7 +252,7 @@ func TestHASubnetRouterFailover(t *testing.T) { scenario, err := NewScenario(dockertestMaxWait()) assertNoErrf(t, "failed to create scenario: %s", err) - // defer scenario.Shutdown() + defer scenario.Shutdown() spec := map[string]int{ user: 3,