tstest/natlab/vnet: move some boilerplate to mkPacket helper

No need to make callers specify the redundant IP version or
TTL/HopLimit or EthernetType in the common case. The mkPacket helper
can set those when unset.

And use the mkIPLayer in another place, simplifying some code.

And rename mkPacketErr to just mkPacket, then move mkPacket to
test-only code, as mustPacket.

Updates #13038

Change-Id: Ic216e44dda760c69ab9bfc509370040874a47d30
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2024-08-30 19:23:49 -07:00
committed by Brad Fitzpatrick
parent 7e88d6712e
commit 3d9e3a17fa
2 changed files with 74 additions and 83 deletions

View File

@@ -245,12 +245,9 @@ func (n *network) handleIPPacketFromGvisor(ipRaw []byte) {
return
}
eth := &layers.Ethernet{
SrcMAC: n.mac.HWAddr(),
DstMAC: node.mac.HWAddr(),
EthernetType: flow.etherType(),
SrcMAC: n.mac.HWAddr(),
DstMAC: node.mac.HWAddr(),
}
buffer := gopacket.NewSerializeBuffer()
options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}
sls := []gopacket.SerializableLayer{
eth,
}
@@ -259,21 +256,16 @@ func (n *network) handleIPPacketFromGvisor(ipRaw []byte) {
if !ok {
log.Fatalf("layer %s is not serializable", layer.LayerType().String())
}
switch gl := layer.(type) {
case *layers.TCP:
gl.SetNetworkLayerForChecksum(goPkt.NetworkLayer())
case *layers.UDP:
gl.SetNetworkLayerForChecksum(goPkt.NetworkLayer())
}
sls = append(sls, sl)
}
if err := gopacket.SerializeLayers(buffer, options, sls...); err != nil {
resPkt, err := mkPacket(sls...)
if err != nil {
n.logf("gvisor: serialize error: %v", err)
return
}
if nw, ok := n.writers.Load(node.mac); ok {
nw.write(buffer.Bytes())
nw.write(resPkt)
} else {
n.logf("gvisor write: no writeFunc for %v", node.mac)
}
@@ -1168,9 +1160,8 @@ func (n *network) WriteUDPPacketNoNAT(p UDPPacket) {
}
eth := &layers.Ethernet{
SrcMAC: n.mac.HWAddr(), // of gateway
DstMAC: node.mac.HWAddr(),
EthernetType: p.etherType(),
SrcMAC: n.mac.HWAddr(), // of gateway
DstMAC: node.mac.HWAddr(),
}
ethRaw, err := n.serializedUDPPacket(src, dst, p.Payload, eth)
if err != nil {
@@ -1188,8 +1179,6 @@ type serializableNetworkLayer interface {
func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetworkLayer {
if src.Is4() {
return &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: proto,
SrcIP: src.AsSlice(),
DstIP: dst.AsSlice(),
@@ -1197,8 +1186,6 @@ func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetwork
}
if src.Is6() {
return &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: proto,
SrcIP: src.AsSlice(),
DstIP: dst.AsSlice(),
@@ -1219,9 +1206,9 @@ func (n *network) serializedUDPPacket(src, dst netip.AddrPort, payload []byte, e
DstPort: layers.UDPPort(dst.Port()),
}
if eth == nil {
return mkPacketErr(ip, udp, gopacket.Payload(payload))
return mkPacket(ip, udp, gopacket.Payload(payload))
} else {
return mkPacketErr(eth, ip, udp, gopacket.Payload(payload))
return mkPacket(eth, ip, udp, gopacket.Payload(payload))
}
}
@@ -1404,9 +1391,8 @@ func (n *network) handleIPv6RouterSolicitation(ep EthernetPacket, rs *layers.ICM
}
n.logf("sending IPv6 router advertisement to %v from %v", eth.DstMAC, eth.SrcMAC)
ip := &layers.IPv6{
Version: 6,
HopLimit: 255,
NextHeader: layers.IPProtocolICMPv6,
HopLimit: 255, // per RFC 4861, 7.1.1 etc (all NDP messages); don't use mkPacket's default of 64
SrcIP: net.ParseIP("fe80::1"),
DstIP: v6.SrcIP,
}
@@ -1431,7 +1417,7 @@ func (n *network) handleIPv6RouterSolicitation(ep EthernetPacket, rs *layers.ICM
},
},
}
pkt, err := mkPacketErr(eth, ip, icmp, ra)
pkt, err := mkPacket(eth, ip, icmp, ra)
if err != nil {
n.logf("serializing ICMPv6 RA: %v", err)
return
@@ -1462,8 +1448,7 @@ func (n *network) handleIPv6NeighborSolicitation(ep EthernetPacket, ns *layers.I
EthernetType: layers.EthernetTypeIPv6,
}
ip := &layers.IPv6{
Version: 6,
HopLimit: 255,
HopLimit: 255, // per RFC 4861, 7.1.1 etc (all NDP messages); don't use mkPacket's default of 64
NextHeader: layers.IPProtocolICMPv6,
SrcIP: ns.TargetAddress,
DstIP: v6.SrcIP,
@@ -1485,7 +1470,7 @@ func (n *network) handleIPv6NeighborSolicitation(ep EthernetPacket, ns *layers.I
Type: layers.ICMPv6OptTargetAddress,
Data: srcMAC.HWAddr(),
})
pkt, err := mkPacketErr(eth, ip, icmp, na)
pkt, err := mkPacket(eth, ip, icmp, na)
if err != nil {
n.logf("serializing ICMPv6 NA: %v", err)
}
@@ -1578,8 +1563,6 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) {
EthernetType: layers.EthernetTypeIPv4, // never IPv6 for DHCP
}
ip := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: ipLayer.DstIP,
DstIP: ipLayer.SrcIP,
@@ -1588,7 +1571,7 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) {
SrcPort: udpLayer.DstPort,
DstPort: udpLayer.SrcPort,
}
return mkPacketErr(eth, ip, udp, response)
return mkPacket(eth, ip, udp, response)
}
// isDHCPRequest reports whether pkt is a DHCPv4 request.
@@ -1654,20 +1637,6 @@ type ipSrcDst struct {
dst netip.Addr
}
func (f ipSrcDst) etherType() layers.EthernetType {
if f.dst.Is6() {
return layers.EthernetTypeIPv6
}
return layers.EthernetTypeIPv4
}
func (p UDPPacket) etherType() layers.EthernetType {
if p.Dst.Addr().Is6() {
return layers.EthernetTypeIPv6
}
return layers.EthernetTypeIPv4
}
func flow(gp gopacket.Packet) (f ipSrcDst, ok bool) {
if gp == nil {
return f, false
@@ -1778,9 +1747,8 @@ func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) {
// Make reply layers, all reversed.
eth2 := &layers.Ethernet{
SrcMAC: ethLayer.DstMAC,
DstMAC: ethLayer.SrcMAC,
EthernetType: flow.etherType(),
SrcMAC: ethLayer.DstMAC,
DstMAC: ethLayer.SrcMAC,
}
ip2 := mkIPLayer(layers.IPProtocolUDP, flow.dst, flow.src)
udp2 := &layers.UDP{
@@ -1788,7 +1756,7 @@ func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) {
DstPort: udpLayer.SrcPort,
}
resPkt, err := mkPacketErr(eth2, ip2, udp2, response)
resPkt, err := mkPacket(eth2, ip2, udp2, response)
if err != nil {
return nil, err
}
@@ -2180,18 +2148,48 @@ func (c *NodeAgentClient) EnableHostFirewall(ctx context.Context) error {
return nil
}
func mkPacket(layers ...gopacket.SerializableLayer) []byte {
return must.Get(mkPacketErr(layers...))
}
func mkPacketErr(ll ...gopacket.SerializableLayer) ([]byte, error) {
// mkPacket is a serializes a number of layers into a packet.
//
// It's a convenience wrapper around gopacket.SerializeLayers
// that does some things automatically:
//
// * layers.Ethernet.EthernetType is set to IPv4 or IPv6 if not already set
// * layers.IPv4/IPv6 Version is set to 4/6 if not already set
// * layers.IPv4/IPv6 TTL/HopLimit is set to 64 if not already set
// * the TCP/UDP/ICMPv6 checksum is set based on the network layer
//
// The provided layers in ll must be sorted from lowest (e.g. *layers.Ethernet)
// to highest. (Depending on the need, the first layer will be either *layers.Ethernet
// or *layers.IPv4/IPv6).
func mkPacket(ll ...gopacket.SerializableLayer) ([]byte, error) {
var el *layers.Ethernet
var nl gopacket.NetworkLayer
for _, la := range ll {
switch la := la.(type) {
case *layers.IPv4:
nl = la
if el != nil && el.EthernetType == 0 {
el.EthernetType = layers.EthernetTypeIPv4
}
if la.Version == 0 {
la.Version = 4
}
if la.TTL == 0 {
la.TTL = 64
}
case *layers.IPv6:
nl = la
if el != nil && el.EthernetType == 0 {
el.EthernetType = layers.EthernetTypeIPv6
}
if la.Version == 0 {
la.Version = 6
}
if la.HopLimit == 0 {
la.HopLimit = 64
}
case *layers.Ethernet:
el = la
}
}
for _, la := range ll {