mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-23 01:11:40 +00:00
wgengine/magicsock: use cloud metadata to get public IPs
Updates #12774 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I1661b6a2da7966ab667b075894837afd96f4742f
This commit is contained in:
parent
4055b63b9b
commit
9939374c48
182
wgengine/magicsock/cloudinfo.go
Normal file
182
wgengine/magicsock/cloudinfo.go
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build !(ios || android || js)
|
||||||
|
|
||||||
|
package magicsock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
"tailscale.com/util/cloudenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxCloudInfoWait = 2 * time.Second
|
||||||
|
|
||||||
|
type cloudInfo struct {
|
||||||
|
client http.Client
|
||||||
|
logf logger.Logf
|
||||||
|
|
||||||
|
// The following parameters are fixed for the lifetime of the cloudInfo
|
||||||
|
// object, but are used for testing.
|
||||||
|
cloud cloudenv.Cloud
|
||||||
|
endpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCloudInfo(logf logger.Logf) *cloudInfo {
|
||||||
|
tr := &http.Transport{
|
||||||
|
DisableKeepAlives: true,
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: maxCloudInfoWait,
|
||||||
|
}).Dial,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cloudInfo{
|
||||||
|
client: http.Client{Transport: tr},
|
||||||
|
logf: logf,
|
||||||
|
cloud: cloudenv.Get(),
|
||||||
|
endpoint: "http://" + cloudenv.CommonNonRoutableMetadataIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPublicIPs returns any public IPs attached to the current cloud instance,
|
||||||
|
// if the tailscaled process is running in a known cloud and there are any such
|
||||||
|
// IPs present.
|
||||||
|
func (ci *cloudInfo) GetPublicIPs(ctx context.Context) ([]netip.Addr, error) {
|
||||||
|
switch ci.cloud {
|
||||||
|
case cloudenv.AWS:
|
||||||
|
ret, err := ci.getAWS(ctx)
|
||||||
|
ci.logf("[v1] cloudinfo.GetPublicIPs: AWS: %v, %v", ret, err)
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAWSMetadata makes a request to the AWS metadata service at the given
|
||||||
|
// path, authenticating with the provided IMDSv2 token. The returned metadata
|
||||||
|
// is split by newline and returned as a slice.
|
||||||
|
func (ci *cloudInfo) getAWSMetadata(ctx context.Context, token, path string) ([]string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", ci.endpoint+path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating request to %q: %w", path, err)
|
||||||
|
}
|
||||||
|
req.Header.Set("X-aws-ec2-metadata-token", token)
|
||||||
|
|
||||||
|
resp, err := ci.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("making request to metadata service %q: %w", path, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusOK:
|
||||||
|
// Good
|
||||||
|
case http.StatusNotFound:
|
||||||
|
// Nothing found, but this isn't an error; just return
|
||||||
|
return nil, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading response body for %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Split(strings.TrimSpace(string(body)), "\n"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAWS returns all public IPv4 and IPv6 addresses present in the AWS instance metadata.
|
||||||
|
func (ci *cloudInfo) getAWS(ctx context.Context) ([]netip.Addr, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, maxCloudInfoWait)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Get a token so we can query the metadata service.
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "PUT", ci.endpoint+"/latest/api/token", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating token request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Aws-Ec2-Metadata-Token-Ttl-Seconds", "10")
|
||||||
|
|
||||||
|
resp, err := ci.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("making token request to metadata service: %w", err)
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading token response body: %w", err)
|
||||||
|
}
|
||||||
|
token := string(body)
|
||||||
|
|
||||||
|
server := resp.Header.Get("Server")
|
||||||
|
if server != "EC2ws" {
|
||||||
|
return nil, fmt.Errorf("unexpected server header: %q", server)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over all interfaces and get their public IP addresses, both IPv4 and IPv6.
|
||||||
|
macAddrs, err := ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting interface MAC addresses: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
addrs []netip.Addr
|
||||||
|
errs []error
|
||||||
|
)
|
||||||
|
|
||||||
|
addAddr := func(addr string) {
|
||||||
|
ip, err := netip.ParseAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("parsing IP address %q: %w", addr, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addrs = append(addrs, ip)
|
||||||
|
}
|
||||||
|
for _, mac := range macAddrs {
|
||||||
|
ips, err := ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/"+mac+"/public-ipv4s")
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("getting IPv4 addresses for %q: %w", mac, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
addAddr(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try querying for IPv6 addresses.
|
||||||
|
ips, err = ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/"+mac+"/ipv6s")
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("getting IPv6 addresses for %q: %w", mac, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, ip := range ips {
|
||||||
|
addAddr(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the returned addresses for determinism.
|
||||||
|
slices.SortFunc(addrs, func(a, b netip.Addr) int {
|
||||||
|
return a.Compare(b)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Preferentially return any addresses we found, even if there were errors.
|
||||||
|
if len(addrs) > 0 {
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return nil, fmt.Errorf("getting IP addresses: %w", errors.Join(errs...))
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
23
wgengine/magicsock/cloudinfo_nocloud.go
Normal file
23
wgengine/magicsock/cloudinfo_nocloud.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build ios || android || js
|
||||||
|
|
||||||
|
package magicsock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cloudInfo struct{}
|
||||||
|
|
||||||
|
func newCloudInfo(_ logger.Logf) *cloudInfo {
|
||||||
|
return &cloudInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ci *cloudInfo) GetPublicIPs(_ context.Context) ([]netip.Addr, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
123
wgengine/magicsock/cloudinfo_test.go
Normal file
123
wgengine/magicsock/cloudinfo_test.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package magicsock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/util/cloudenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCloudInfo_AWS(t *testing.T) {
|
||||||
|
const (
|
||||||
|
mac1 = "06:1d:00:00:00:00"
|
||||||
|
mac2 = "06:1d:00:00:00:01"
|
||||||
|
publicV4 = "1.2.3.4"
|
||||||
|
otherV4_1 = "5.6.7.8"
|
||||||
|
otherV4_2 = "11.12.13.14"
|
||||||
|
v6addr = "2001:db8::1"
|
||||||
|
|
||||||
|
macsPrefix = "/latest/meta-data/network/interfaces/macs/"
|
||||||
|
)
|
||||||
|
// Launch a fake AWS IMDS server
|
||||||
|
fake := &fakeIMDS{
|
||||||
|
tb: t,
|
||||||
|
paths: map[string]string{
|
||||||
|
macsPrefix: mac1 + "\n" + mac2,
|
||||||
|
// This is the "main" public IP address for the instance
|
||||||
|
macsPrefix + mac1 + "/public-ipv4s": publicV4,
|
||||||
|
|
||||||
|
// There's another interface with two public IPs
|
||||||
|
// attached to it and an IPv6 address, all of which we
|
||||||
|
// should discover.
|
||||||
|
macsPrefix + mac2 + "/public-ipv4s": otherV4_1 + "\n" + otherV4_2,
|
||||||
|
macsPrefix + mac2 + "/ipv6s": v6addr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(fake)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
ci := newCloudInfo(t.Logf)
|
||||||
|
ci.cloud = cloudenv.AWS
|
||||||
|
ci.endpoint = srv.URL
|
||||||
|
|
||||||
|
ips, err := ci.GetPublicIPs(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantIPs := []netip.Addr{
|
||||||
|
netip.MustParseAddr(publicV4),
|
||||||
|
netip.MustParseAddr(otherV4_1),
|
||||||
|
netip.MustParseAddr(otherV4_2),
|
||||||
|
netip.MustParseAddr(v6addr),
|
||||||
|
}
|
||||||
|
if !slices.Equal(ips, wantIPs) {
|
||||||
|
t.Fatalf("got %v, want %v", ips, wantIPs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloudInfo_AWSNotPublic(t *testing.T) {
|
||||||
|
returns404 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "PUT" && r.URL.Path == "/latest/api/token" {
|
||||||
|
w.Header().Set("Server", "EC2ws")
|
||||||
|
w.Write([]byte("fake-imds-token"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
srv := httptest.NewServer(returns404)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
ci := newCloudInfo(t.Logf)
|
||||||
|
ci.cloud = cloudenv.AWS
|
||||||
|
ci.endpoint = srv.URL
|
||||||
|
|
||||||
|
// If the IMDS server doesn't return any public IPs, it's not an error
|
||||||
|
// and we should just get an empty list.
|
||||||
|
ips, err := ci.GetPublicIPs(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(ips) != 0 {
|
||||||
|
t.Fatalf("got %v, want none", ips)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeIMDS struct {
|
||||||
|
tb testing.TB
|
||||||
|
paths map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeIMDS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
f.tb.Logf("%s %s", r.Method, r.URL.Path)
|
||||||
|
path := r.URL.Path
|
||||||
|
|
||||||
|
// Handle the /latest/api/token case
|
||||||
|
const token = "fake-imds-token"
|
||||||
|
if r.Method == "PUT" && path == "/latest/api/token" {
|
||||||
|
w.Header().Set("Server", "EC2ws")
|
||||||
|
w.Write([]byte(token))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, require the IMDSv2 token to be set
|
||||||
|
if r.Header.Get("X-aws-ec2-metadata-token") != token {
|
||||||
|
f.tb.Errorf("missing or invalid IMDSv2 token")
|
||||||
|
http.Error(w, "missing or invalid IMDSv2 token", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := f.paths[path]; ok {
|
||||||
|
w.Write([]byte(v))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
@ -133,6 +133,9 @@ type Conn struct {
|
|||||||
// bind is the wireguard-go conn.Bind for Conn.
|
// bind is the wireguard-go conn.Bind for Conn.
|
||||||
bind *connBind
|
bind *connBind
|
||||||
|
|
||||||
|
// cloudInfo is used to query cloud metadata services.
|
||||||
|
cloudInfo *cloudInfo
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
// Fields that must be accessed via atomic load/stores.
|
// Fields that must be accessed via atomic load/stores.
|
||||||
|
|
||||||
@ -425,9 +428,10 @@ func (o *Options) derpActiveFunc() func() {
|
|||||||
|
|
||||||
// newConn is the error-free, network-listening-side-effect-free based
|
// newConn is the error-free, network-listening-side-effect-free based
|
||||||
// of NewConn. Mostly for tests.
|
// of NewConn. Mostly for tests.
|
||||||
func newConn() *Conn {
|
func newConn(logf logger.Logf) *Conn {
|
||||||
discoPrivate := key.NewDisco()
|
discoPrivate := key.NewDisco()
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
|
logf: logf,
|
||||||
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
|
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
|
||||||
derpStarted: make(chan struct{}),
|
derpStarted: make(chan struct{}),
|
||||||
peerLastDerp: make(map[key.NodePublic]int),
|
peerLastDerp: make(map[key.NodePublic]int),
|
||||||
@ -435,6 +439,7 @@ func newConn() *Conn {
|
|||||||
discoInfo: make(map[key.DiscoPublic]*discoInfo),
|
discoInfo: make(map[key.DiscoPublic]*discoInfo),
|
||||||
discoPrivate: discoPrivate,
|
discoPrivate: discoPrivate,
|
||||||
discoPublic: discoPrivate.Public(),
|
discoPublic: discoPrivate.Public(),
|
||||||
|
cloudInfo: newCloudInfo(logf),
|
||||||
}
|
}
|
||||||
c.discoShort = c.discoPublic.ShortString()
|
c.discoShort = c.discoPublic.ShortString()
|
||||||
c.bind = &connBind{Conn: c, closed: true}
|
c.bind = &connBind{Conn: c, closed: true}
|
||||||
@ -462,10 +467,9 @@ func NewConn(opts Options) (*Conn, error) {
|
|||||||
return nil, errors.New("magicsock.Options.NetMon must be non-nil")
|
return nil, errors.New("magicsock.Options.NetMon must be non-nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newConn()
|
c := newConn(opts.logf())
|
||||||
c.port.Store(uint32(opts.Port))
|
c.port.Store(uint32(opts.Port))
|
||||||
c.controlKnobs = opts.ControlKnobs
|
c.controlKnobs = opts.ControlKnobs
|
||||||
c.logf = opts.logf()
|
|
||||||
c.epFunc = opts.endpointsFunc()
|
c.epFunc = opts.endpointsFunc()
|
||||||
c.derpActiveFunc = opts.derpActiveFunc()
|
c.derpActiveFunc = opts.derpActiveFunc()
|
||||||
c.idleFunc = opts.IdleFunc
|
c.idleFunc = opts.IdleFunc
|
||||||
@ -952,6 +956,27 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
|
|||||||
addAddr(ap, tailcfg.EndpointExplicitConf)
|
addAddr(ap, tailcfg.EndpointExplicitConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we're on a cloud instance, we might have a public IPv4 or IPv6
|
||||||
|
// address that we can be reached at. Find those, if they exist, and
|
||||||
|
// add them.
|
||||||
|
if addrs, err := c.cloudInfo.GetPublicIPs(ctx); err == nil {
|
||||||
|
var port4, port6 uint16
|
||||||
|
if addr := c.pconn4.LocalAddr(); addr != nil {
|
||||||
|
port4 = uint16(addr.Port)
|
||||||
|
}
|
||||||
|
if addr := c.pconn6.LocalAddr(); addr != nil {
|
||||||
|
port6 = uint16(addr.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
if addr.Is4() && port4 > 0 {
|
||||||
|
addAddr(netip.AddrPortFrom(addr, port4), tailcfg.EndpointLocal)
|
||||||
|
} else if addr.Is6() && port6 > 0 {
|
||||||
|
addAddr(netip.AddrPortFrom(addr, port6), tailcfg.EndpointLocal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Update our set of endpoints by adding any endpoints that we
|
// Update our set of endpoints by adding any endpoints that we
|
||||||
// previously found but haven't expired yet. This also updates the
|
// previously found but haven't expired yet. This also updates the
|
||||||
// cache with the set of endpoints discovered in this function.
|
// cache with the set of endpoints discovered in this function.
|
||||||
|
@ -452,7 +452,7 @@ func TestPickDERPFallback(t *testing.T) {
|
|||||||
tstest.PanicOnLog()
|
tstest.PanicOnLog()
|
||||||
tstest.ResourceCheck(t)
|
tstest.ResourceCheck(t)
|
||||||
|
|
||||||
c := newConn()
|
c := newConn(t.Logf)
|
||||||
dm := &tailcfg.DERPMap{
|
dm := &tailcfg.DERPMap{
|
||||||
Regions: map[int]*tailcfg.DERPRegion{
|
Regions: map[int]*tailcfg.DERPRegion{
|
||||||
1: {},
|
1: {},
|
||||||
@ -483,7 +483,7 @@ func TestPickDERPFallback(t *testing.T) {
|
|||||||
// distribution over nodes works.
|
// distribution over nodes works.
|
||||||
got := map[int]int{}
|
got := map[int]int{}
|
||||||
for range 50 {
|
for range 50 {
|
||||||
c = newConn()
|
c = newConn(t.Logf)
|
||||||
c.derpMap = dm
|
c.derpMap = dm
|
||||||
got[c.pickDERPFallback()]++
|
got[c.pickDERPFallback()]++
|
||||||
}
|
}
|
||||||
@ -1185,8 +1185,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDiscoMessage(t *testing.T) {
|
func TestDiscoMessage(t *testing.T) {
|
||||||
c := newConn()
|
c := newConn(t.Logf)
|
||||||
c.logf = t.Logf
|
|
||||||
c.privateKey = key.NewNode()
|
c.privateKey = key.NewNode()
|
||||||
|
|
||||||
peer1Pub := c.DiscoPublicKey()
|
peer1Pub := c.DiscoPublicKey()
|
||||||
@ -3161,8 +3160,7 @@ func TestMaybeSetNearestDERP(t *testing.T) {
|
|||||||
for _, tt := range testCases {
|
for _, tt := range testCases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ht := new(health.Tracker)
|
ht := new(health.Tracker)
|
||||||
c := newConn()
|
c := newConn(t.Logf)
|
||||||
c.logf = t.Logf
|
|
||||||
c.myDerp = tt.old
|
c.myDerp = tt.old
|
||||||
c.derpMap = derpMap
|
c.derpMap = derpMap
|
||||||
c.health = ht
|
c.health = ht
|
||||||
|
Loading…
x
Reference in New Issue
Block a user