cmd/tailscaled,net/tstun: fix data race on start-up in TUN mode

Fixes #7894

Change-Id: Ice3f8019405714dd69d02bc07694f3872bb598b8

Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-10-13 19:41:10 +00:00 committed by Brad Fitzpatrick
parent 5c555cdcbb
commit 5297bd2cff
7 changed files with 42 additions and 2 deletions

View File

@ -540,6 +540,10 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID
} }
sys.Set(store) sys.Set(store)
if w, ok := sys.Tun.GetOK(); ok {
w.Start()
}
lb, err := ipnlocal.NewLocalBackend(logf, logID, sys, opts.LoginFlags) lb, err := ipnlocal.NewLocalBackend(logf, logID, sys, opts.LoginFlags)
if err != nil { if err != nil {
return nil, fmt.Errorf("ipnlocal.NewLocalBackend: %w", err) return nil, fmt.Errorf("ipnlocal.NewLocalBackend: %w", err)

View File

@ -55,3 +55,4 @@ func (t *fakeTUN) MTU() (int, error) { return 1500, nil }
func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil }
func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan }
func (t *fakeTUN) BatchSize() int { return 1 } func (t *fakeTUN) BatchSize() int { return 1 }
func (t *fakeTUN) IsFakeTun() bool { return true }

View File

@ -78,6 +78,9 @@
type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response
// Wrapper augments a tun.Device with packet filtering and injection. // Wrapper augments a tun.Device with packet filtering and injection.
//
// A Wrapper starts in a "corked" mode where Read calls are blocked
// until the Wrapper's Start method is called.
type Wrapper struct { type Wrapper struct {
logf logger.Logf logf logger.Logf
limitedLogf logger.Logf // aggressively rate-limited logf used for potentially high volume errors limitedLogf logger.Logf // aggressively rate-limited logf used for potentially high volume errors
@ -85,6 +88,9 @@ type Wrapper struct {
tdev tun.Device tdev tun.Device
isTAP bool // whether tdev is a TAP device isTAP bool // whether tdev is a TAP device
started atomic.Bool // whether Start has been called
startCh chan struct{} // closed in Start
closeOnce sync.Once closeOnce sync.Once
// lastActivityAtomic is read/written atomically. // lastActivityAtomic is read/written atomically.
@ -219,6 +225,16 @@ type setWrapperer interface {
setWrapper(*Wrapper) setWrapper(*Wrapper)
} }
// Start unblocks any Wrapper.Read calls that have already started
// and makes the Wrapper functional.
//
// Start must be called exactly once after the various Tailscale
// subsystems have been wired up to each other.
func (w *Wrapper) Start() {
w.started.Store(true)
close(w.startCh)
}
func WrapTAP(logf logger.Logf, tdev tun.Device) *Wrapper { func WrapTAP(logf logger.Logf, tdev tun.Device) *Wrapper {
return wrap(logf, tdev, true) return wrap(logf, tdev, true)
} }
@ -244,6 +260,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper {
eventsOther: make(chan tun.Event), eventsOther: make(chan tun.Event),
// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
filterFlags: filter.LogAccepts | filter.LogDrops, filterFlags: filter.LogAccepts | filter.LogDrops,
startCh: make(chan struct{}),
} }
w.vectorBuffer = make([][]byte, tdev.BatchSize()) w.vectorBuffer = make([][]byte, tdev.BatchSize())
@ -309,6 +326,9 @@ func (t *Wrapper) isSelfDisco(p *packet.Parsed) bool {
func (t *Wrapper) Close() error { func (t *Wrapper) Close() error {
var err error var err error
t.closeOnce.Do(func() { t.closeOnce.Do(func() {
if t.started.CompareAndSwap(false, true) {
close(t.startCh)
}
close(t.closed) close(t.closed)
t.bufferConsumedMu.Lock() t.bufferConsumedMu.Lock()
t.bufferConsumedClosed = true t.bufferConsumedClosed = true
@ -836,6 +856,9 @@ func (t *Wrapper) IdleDuration() time.Duration {
} }
func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
if !t.started.Load() {
<-t.startCh
}
// packet from OS read and sent to WG // packet from OS read and sent to WG
res, ok := <-t.vectorOutbound res, ok := <-t.vectorOutbound
if !ok { if !ok {

View File

@ -178,6 +178,7 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper
} else { } else {
tun.disableFilter = true tun.disableFilter = true
} }
tun.Start()
return chtun, tun return chtun, tun
} }

View File

@ -47,6 +47,10 @@ type System struct {
StateStore SubSystem[ipn.StateStore] StateStore SubSystem[ipn.StateStore]
Netstack SubSystem[NetstackImpl] // actually a *netstack.Impl Netstack SubSystem[NetstackImpl] // actually a *netstack.Impl
// onlyNetstack is whether the Tun value is a fake TUN device
// and we're using netstack for everything.
onlyNetstack bool
controlKnobs controlknobs.Knobs controlKnobs controlknobs.Knobs
proxyMap proxymap.Mapper proxyMap proxymap.Mapper
} }
@ -74,6 +78,12 @@ func (s *System) Set(v any) {
case router.Router: case router.Router:
s.Router.Set(v) s.Router.Set(v)
case *tstun.Wrapper: case *tstun.Wrapper:
type ft interface {
IsFakeTun() bool
}
if _, ok := v.Unwrap().(ft); ok {
s.onlyNetstack = true
}
s.Tun.Set(v) s.Tun.Set(v)
case *magicsock.Conn: case *magicsock.Conn:
s.MagicSock.Set(v) s.MagicSock.Set(v)
@ -97,8 +107,7 @@ func (s *System) IsNetstackRouter() bool {
// IsNetstack reports whether Tailscale is running as a netstack-based TUN-free engine. // IsNetstack reports whether Tailscale is running as a netstack-based TUN-free engine.
func (s *System) IsNetstack() bool { func (s *System) IsNetstack() bool {
name, _ := s.Tun.Get().Name() return s.onlyNetstack
return name == tstun.FakeTUNName
} }
// ControlKnobs returns the control knobs for this node. // ControlKnobs returns the control knobs for this node.

View File

@ -530,6 +530,7 @@ func (s *Server) start() (reterr error) {
if err != nil { if err != nil {
return fmt.Errorf("netstack.Create: %w", err) return fmt.Errorf("netstack.Create: %w", err)
} }
sys.Tun.Get().Start()
sys.Set(ns) sys.Set(ns)
ns.ProcessLocalIPs = true ns.ProcessLocalIPs = true
ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow

View File

@ -184,6 +184,7 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListen
tun := tuntest.NewChannelTUN() tun := tuntest.NewChannelTUN()
tsTun := tstun.Wrap(logf, tun.TUN()) tsTun := tstun.Wrap(logf, tun.TUN())
tsTun.SetFilter(filter.NewAllowAllForTest(logf)) tsTun.SetFilter(filter.NewAllowAllForTest(logf))
tsTun.Start()
wgLogger := wglog.NewLogger(logf) wgLogger := wglog.NewLogger(logf)
dev := wgcfg.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger) dev := wgcfg.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger)