wgengine/magicsock: use cloud metadata to get public IPs

Updates #12774

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I1661b6a2da7966ab667b075894837afd96f4742f
This commit is contained in:
Andrew Dunham
2024-07-10 16:46:31 -05:00
parent 4055b63b9b
commit 9939374c48
5 changed files with 360 additions and 9 deletions

View File

@@ -133,6 +133,9 @@ type Conn struct {
// bind is the wireguard-go conn.Bind for Conn.
bind *connBind
// cloudInfo is used to query cloud metadata services.
cloudInfo *cloudInfo
// ============================================================
// Fields that must be accessed via atomic load/stores.
@@ -425,9 +428,10 @@ func (o *Options) derpActiveFunc() func() {
// newConn is the error-free, network-listening-side-effect-free based
// of NewConn. Mostly for tests.
func newConn() *Conn {
func newConn(logf logger.Logf) *Conn {
discoPrivate := key.NewDisco()
c := &Conn{
logf: logf,
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
derpStarted: make(chan struct{}),
peerLastDerp: make(map[key.NodePublic]int),
@@ -435,6 +439,7 @@ func newConn() *Conn {
discoInfo: make(map[key.DiscoPublic]*discoInfo),
discoPrivate: discoPrivate,
discoPublic: discoPrivate.Public(),
cloudInfo: newCloudInfo(logf),
}
c.discoShort = c.discoPublic.ShortString()
c.bind = &connBind{Conn: c, closed: true}
@@ -462,10 +467,9 @@ func NewConn(opts Options) (*Conn, error) {
return nil, errors.New("magicsock.Options.NetMon must be non-nil")
}
c := newConn()
c := newConn(opts.logf())
c.port.Store(uint32(opts.Port))
c.controlKnobs = opts.ControlKnobs
c.logf = opts.logf()
c.epFunc = opts.endpointsFunc()
c.derpActiveFunc = opts.derpActiveFunc()
c.idleFunc = opts.IdleFunc
@@ -952,6 +956,27 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
addAddr(ap, tailcfg.EndpointExplicitConf)
}
// If we're on a cloud instance, we might have a public IPv4 or IPv6
// address that we can be reached at. Find those, if they exist, and
// add them.
if addrs, err := c.cloudInfo.GetPublicIPs(ctx); err == nil {
var port4, port6 uint16
if addr := c.pconn4.LocalAddr(); addr != nil {
port4 = uint16(addr.Port)
}
if addr := c.pconn6.LocalAddr(); addr != nil {
port6 = uint16(addr.Port)
}
for _, addr := range addrs {
if addr.Is4() && port4 > 0 {
addAddr(netip.AddrPortFrom(addr, port4), tailcfg.EndpointLocal)
} else if addr.Is6() && port6 > 0 {
addAddr(netip.AddrPortFrom(addr, port6), tailcfg.EndpointLocal)
}
}
}
// Update our set of endpoints by adding any endpoints that we
// previously found but haven't expired yet. This also updates the
// cache with the set of endpoints discovered in this function.