From 432e975a7f7b2e70f1f2ceb67aaf1d43e2902912 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 9 Aug 2023 22:56:21 +0200 Subject: [PATCH] move MapResponse peer logic into function and reuse Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 2 + hscontrol/db/routes.go | 6 +- hscontrol/mapper/mapper.go | 317 ++++++++++++++++---------------- hscontrol/mapper/mapper_test.go | 20 +- hscontrol/poll.go | 5 + integration/general_test.go | 5 +- integration/tsic/tsic.go | 11 +- 7 files changed, 193 insertions(+), 173 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index bff9b98c..57fb5848 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -92,6 +92,8 @@ type Headscale struct { shutdownChan chan struct{} pollNetMapStreamWG sync.WaitGroup + + pollStreamOpenMu sync.Mutex } func NewHeadscale(cfg *types.Config) (*Headscale, error) { diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 834f53ec..af6b744f 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -340,6 +340,8 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { continue } + machine := &route.Machine + if !route.IsPrimary { _, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix)) if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { @@ -355,7 +357,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { return err } - changedMachines = append(changedMachines, &route.Machine) + changedMachines = append(changedMachines, machine) continue } @@ -429,7 +431,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { return err } - changedMachines = append(changedMachines, &route.Machine) + changedMachines = append(changedMachines, machine) } } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index f5a2ddc6..67e078d4 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -38,6 +38,16 @@ const ( var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH") +// TODO: Optimise +// As this work continues, the idea is that there will be one Mapper instance +// per node, attached to the open stream between the control and client. +// This means that this can hold a state per machine and we can use that to +// improve the mapresponses sent. +// We could: +// - Keep information about the previous mapresponse so we can send a diff +// - Store hashes +// - Create a "minifier" that removes info not needed for the node + type Mapper struct { privateKey2019 *key.MachinePrivate isNoise bool @@ -102,105 +112,6 @@ func (m *Mapper) String() string { return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created) } -// TODO: Optimise -// As this work continues, the idea is that there will be one Mapper instance -// per node, attached to the open stream between the control and client. -// This means that this can hold a state per machine and we can use that to -// improve the mapresponses sent. -// We could: -// - Keep information about the previous mapresponse so we can send a diff -// - Store hashes -// - Create a "minifier" that removes info not needed for the node - -// fullMapResponse is the internal function for generating a MapResponse -// for a machine. -func fullMapResponse( - pol *policy.ACLPolicy, - machine *types.Machine, - peers types.Machines, - - baseDomain string, - dnsCfg *tailcfg.DNSConfig, - derpMap *tailcfg.DERPMap, - logtail bool, - randomClientPort bool, -) (*tailcfg.MapResponse, error) { - tailnode, err := tailNode(machine, pol, dnsCfg, baseDomain) - if err != nil { - return nil, err - } - - now := time.Now() - - resp := tailcfg.MapResponse{ - Node: tailnode, - - DERPMap: derpMap, - - Domain: baseDomain, - - // Do not instruct clients to collect services we do not - // support or do anything with them - CollectServices: "false", - - ControlTime: &now, - KeepAlive: false, - OnlineChange: db.OnlineMachineMap(peers), - - Debug: &tailcfg.Debug{ - DisableLogTail: !logtail, - RandomizeClientPort: randomClientPort, - }, - } - - if peers != nil || len(peers) > 0 { - rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( - pol, - machine, - peers, - ) - if err != nil { - return nil, err - } - - // Filter out peers that have expired. - peers = filterExpiredAndNotReady(peers) - - // If there are filter rules present, see if there are any machines that cannot - // access eachother at all and remove them from the peers. - if len(rules) > 0 { - peers = policy.FilterMachinesByACL(machine, peers, rules) - } - - profiles := generateUserProfiles(machine, peers, baseDomain) - - dnsConfig := generateDNSConfig( - dnsCfg, - baseDomain, - machine, - peers, - ) - - tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain) - if err != nil { - return nil, err - } - - // Peers is always returned sorted by Node.ID. - sort.SliceStable(tailPeers, func(x, y int) bool { - return tailPeers[x].ID < tailPeers[y].ID - }) - - resp.Peers = tailPeers - resp.DNSConfig = dnsConfig - resp.PacketFilter = policy.ReduceFilterRules(machine, rules) - resp.UserProfiles = profiles - resp.SSHPolicy = sshPolicy - } - - return &resp, nil -} - func generateUserProfiles( machine *types.Machine, peers types.Machines, @@ -294,6 +205,38 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine *types.Machine) { } } +// fullMapResponse creates a complete MapResponse for a node. +// It is a separate function to make testing easier. +func (m *Mapper) fullMapResponse( + machine *types.Machine, + pol *policy.ACLPolicy, +) (*tailcfg.MapResponse, error) { + peers := machineMapToList(m.peers) + + resp, err := m.baseWithConfigMapResponse(machine, pol) + if err != nil { + return nil, err + } + + // TODO(kradalby): Move this into appendPeerChanges? + resp.OnlineChange = db.OnlineMachineMap(peers) + + err = appendPeerChanges( + resp, + pol, + machine, + peers, + peers, + m.baseDomain, + m.dnsCfg, + ) + if err != nil { + return nil, err + } + + return resp, nil +} + // FullMapResponse returns a MapResponse for the given machine. func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, @@ -303,25 +246,16 @@ func (m *Mapper) FullMapResponse( m.mu.Lock() defer m.mu.Unlock() - mapResponse, err := fullMapResponse( - pol, - machine, - machineMapToList(m.peers), - m.baseDomain, - m.dnsCfg, - m.derpMap, - m.logtail, - m.randomClientPort, - ) + resp, err := m.fullMapResponse(machine, pol) if err != nil { return nil, err } if m.isNoise { - return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) } - return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) } // LiteMapResponse returns a MapResponse for the given machine. @@ -332,32 +266,23 @@ func (m *Mapper) LiteMapResponse( machine *types.Machine, pol *policy.ACLPolicy, ) ([]byte, error) { - mapResponse, err := fullMapResponse( - pol, - machine, - nil, - m.baseDomain, - m.dnsCfg, - m.derpMap, - m.logtail, - m.randomClientPort, - ) + resp, err := m.baseWithConfigMapResponse(machine, pol) if err != nil { return nil, err } if m.isNoise { - return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) } - return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress) + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) } func (m *Mapper) KeepAliveResponse( mapRequest tailcfg.MapRequest, machine *types.Machine, ) ([]byte, error) { - resp := m.baseMapResponse(machine) + resp := m.baseMapResponse() resp.KeepAlive = true return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) @@ -368,7 +293,7 @@ func (m *Mapper) DERPMapResponse( machine *types.Machine, derpMap tailcfg.DERPMap, ) ([]byte, error) { - resp := m.baseMapResponse(machine) + resp := m.baseMapResponse() resp.DERPMap = &derpMap return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) @@ -383,7 +308,6 @@ func (m *Mapper) PeerChangedResponse( m.mu.Lock() defer m.mu.Unlock() - var err error lastSeen := make(map[tailcfg.NodeID]bool) // Update our internal map. @@ -394,37 +318,21 @@ func (m *Mapper) PeerChangedResponse( lastSeen[tailcfg.NodeID(machine.ID)] = true } - rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( + resp := m.baseMapResponse() + + err := appendPeerChanges( + &resp, pol, machine, machineMapToList(m.peers), + changed, + m.baseDomain, + m.dnsCfg, ) if err != nil { return nil, err } - changed = filterExpiredAndNotReady(changed) - - // If there are filter rules present, see if there are any machines that cannot - // access eachother at all and remove them from the changed. - if len(rules) > 0 { - changed = policy.FilterMachinesByACL(machine, changed, rules) - } - - tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain) - if err != nil { - return nil, err - } - - // Peers is always returned sorted by Node.ID. - sort.SliceStable(tailPeers, func(x, y int) bool { - return tailPeers[x].ID < tailPeers[y].ID - }) - - resp := m.baseMapResponse(machine) - resp.PeersChanged = tailPeers - resp.PacketFilter = policy.ReduceFilterRules(machine, rules) - resp.SSHPolicy = sshPolicy // resp.PeerSeenChange = lastSeen return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) @@ -443,7 +351,7 @@ func (m *Mapper) PeerRemovedResponse( delete(m.peers, uint64(id)) } - resp := m.baseMapResponse(machine) + resp := m.baseMapResponse() resp.PeersRemoved = removed return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) @@ -497,7 +405,7 @@ func (m *Mapper) marshalMapResponse( panic(err) } - now := time.Now().Unix() + now := time.Now().UnixNano() mapResponsePath := path.Join( mPath, @@ -583,7 +491,9 @@ var zstdEncoderPool = &sync.Pool{ }, } -func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse { +// baseMapResponse returns a tailcfg.MapResponse with +// KeepAlive false and ControlTime set to now. +func (m *Mapper) baseMapResponse() tailcfg.MapResponse { now := time.Now() resp := tailcfg.MapResponse{ @@ -591,14 +501,43 @@ func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse { ControlTime: &now, } - // online, err := m.db.ListOnlineMachines(machine) - // if err == nil { - // resp.OnlineChange = online - // } - return resp } +// baseWithConfigMapResponse returns a tailcfg.MapResponse struct +// with the basic configuration from headscale set. +// It is used in for bigger updates, such as full and lite, not +// incremental. +func (m *Mapper) baseWithConfigMapResponse( + machine *types.Machine, + pol *policy.ACLPolicy, +) (*tailcfg.MapResponse, error) { + resp := m.baseMapResponse() + + tailnode, err := tailNode(machine, pol, m.dnsCfg, m.baseDomain) + if err != nil { + return nil, err + } + resp.Node = tailnode + + resp.DERPMap = m.derpMap + + resp.Domain = m.baseDomain + + // Do not instruct clients to collect services we do not + // support or do anything with them + resp.CollectServices = "false" + + resp.KeepAlive = false + + resp.Debug = &tailcfg.Debug{ + DisableLogTail: !m.logtail, + RandomizeClientPort: m.randomClientPort, + } + + return &resp, nil +} + func machineMapToList(machines map[uint64]*types.Machine) types.Machines { ret := make(types.Machines, 0) @@ -617,3 +556,67 @@ func filterExpiredAndNotReady(peers types.Machines) types.Machines { return !item.IsExpired() || len(item.Endpoints) > 0 }) } + +// appendPeerChanges mutates a tailcfg.MapResponse with all the +// necessary changes when peers have changed. +func appendPeerChanges( + resp *tailcfg.MapResponse, + + pol *policy.ACLPolicy, + machine *types.Machine, + peers types.Machines, + changed types.Machines, + baseDomain string, + dnsCfg *tailcfg.DNSConfig, +) error { + fullChange := len(peers) == len(changed) + + rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( + pol, + machine, + peers, + ) + if err != nil { + return err + } + + // Filter out peers that have expired. + changed = filterExpiredAndNotReady(changed) + + // If there are filter rules present, see if there are any machines that cannot + // access eachother at all and remove them from the peers. + if len(rules) > 0 { + changed = policy.FilterMachinesByACL(machine, changed, rules) + } + + profiles := generateUserProfiles(machine, changed, baseDomain) + + dnsConfig := generateDNSConfig( + dnsCfg, + baseDomain, + machine, + peers, + ) + + tailPeers, err := tailNodes(changed, pol, dnsCfg, baseDomain) + if err != nil { + return err + } + + // Peers is always returned sorted by Node.ID. + sort.SliceStable(tailPeers, func(x, y int) bool { + return tailPeers[x].ID < tailPeers[y].ID + }) + + if fullChange { + resp.Peers = tailPeers + } else { + resp.PeersChanged = tailPeers + } + resp.DNSConfig = dnsConfig + resp.PacketFilter = policy.ReduceFilterRules(machine, rules) + resp.UserProfiles = profiles + resp.SSHPolicy = sshPolicy + + return nil +} diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index dd455d1a..c0857f26 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -441,9 +441,11 @@ func Test_fullMapResponse(t *testing.T) { }, }, }, - UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, - SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, - ControlTime: &time.Time{}, + UserProfiles: []tailcfg.UserProfile{ + {LoginName: "mini", DisplayName: "mini"}, + }, + SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, + ControlTime: &time.Time{}, Debug: &tailcfg.Debug{ DisableLogTail: true, }, @@ -454,17 +456,23 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fullMapResponse( - tt.pol, + mappy := NewMapper( tt.machine, tt.peers, + nil, + false, + tt.derpMap, tt.baseDomain, tt.dnsConfig, - tt.derpMap, tt.logtail, tt.randomClientPort, ) + got, err := mappy.fullMapResponse( + tt.machine, + tt.pol, + ) + if (err != nil) != tt.wantErr { t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index b96f1424..dc9763b1 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -55,6 +55,8 @@ func logPollFunc( // handlePoll is the common code for the legacy and Noise protocols to // managed the poll loop. +// +//nolint:gocyclo func (h *Headscale) handlePoll( writer http.ResponseWriter, ctx context.Context, @@ -67,6 +69,7 @@ func (h *Headscale) handlePoll( // following updates missing var updateChan chan types.StateUpdate if mapRequest.Stream { + h.pollStreamOpenMu.Lock() h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() @@ -251,6 +254,8 @@ func (h *Headscale) handlePoll( ctx, cancel := context.WithCancel(ctx) defer cancel() + h.pollStreamOpenMu.Unlock() + for { logInfo("Waiting for update on stream channel") select { diff --git a/integration/general_test.go b/integration/general_test.go index a3e32f71..4de121c7 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -407,9 +407,8 @@ func TestResolveMagicDNS(t *testing.T) { defer scenario.Shutdown() spec := map[string]int{ - // Omit 1.16.2 (-1) because it does not have the FQDN field - "magicdns1": len(MustTestVersions) - 1, - "magicdns2": len(MustTestVersions) - 1, + "magicdns1": len(MustTestVersions), + "magicdns2": len(MustTestVersions), } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index be0ccfc5..6efe70f6 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -20,10 +20,11 @@ import ( ) const ( - tsicHashLength = 6 - defaultPingCount = 10 - dockerContextPath = "../." - headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt" + tsicHashLength = 6 + defaultPingTimeout = 300 * time.Millisecond + defaultPingCount = 10 + dockerContextPath = "../." + headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt" ) var ( @@ -591,7 +592,7 @@ func WithPingUntilDirect(direct bool) PingOption { // TODO(kradalby): Make multiping, go routine magic. func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error { args := pingArgs{ - timeout: 300 * time.Millisecond, + timeout: defaultPingTimeout, count: defaultPingCount, direct: true, }