diff --git a/src/ipv6rwc/ipv6rwc.go b/src/ipv6rwc/ipv6rwc.go index 59f4f022..397417a0 100644 --- a/src/ipv6rwc/ipv6rwc.go +++ b/src/ipv6rwc/ipv6rwc.go @@ -2,9 +2,7 @@ package ipv6rwc import ( "crypto/ed25519" - "errors" - "fmt" - "net" + "net/netip" "sync" "time" @@ -19,15 +17,6 @@ import ( const keyStoreTimeout = 2 * time.Minute -/* -// Out-of-band packet types -const ( - typeKeyDummy = iota // nolint:deadcode,varcheck - typeKeyLookup - typeKeyResponse -) -*/ - type keyArray [ed25519.PublicKeySize]byte type keyStore struct { @@ -40,6 +29,7 @@ type keyStore struct { addrBuffer map[address.Address]*buffer subnetToInfo map[address.Subnet]*keyInfo subnetBuffer map[address.Subnet]*buffer + tunnelHelper TunnelHelper mtu uint64 } @@ -59,10 +49,6 @@ func (k *keyStore) init(c *core.Core) { k.core = c k.address = *address.AddrForKey(k.core.PublicKey()) k.subnet = *address.SubnetForKey(k.core.PublicKey()) - /*if err := k.core.SetOutOfBandHandler(k.oobHandler); err != nil { - err = fmt.Errorf("tun.core.SetOutOfBandHander: %w", err) - panic(err) - }*/ k.core.SetPathNotify(func(key ed25519.PublicKey) { k.update(key) }) @@ -182,49 +168,10 @@ func (k *keyStore) resetTimeout(info *keyInfo) { }) } -/* -func (k *keyStore) oobHandler(fromKey, toKey ed25519.PublicKey, data []byte) { // nolint:unused - if len(data) != 1+ed25519.SignatureSize { - return - } - sig := data[1:] - switch data[0] { - case typeKeyLookup: - snet := *address.SubnetForKey(toKey) - if snet == k.subnet && ed25519.Verify(fromKey, toKey[:], sig) { - // This is looking for at least our subnet (possibly our address) - // Send a response - k.sendKeyResponse(fromKey) - } - case typeKeyResponse: - // TODO keep a list of something to match against... - // Ignore the response if it doesn't match anything of interest... - if ed25519.Verify(fromKey, toKey[:], sig) { - k.update(fromKey) - } - } -} -*/ - func (k *keyStore) sendKeyLookup(partial ed25519.PublicKey) { - /* - sig := ed25519.Sign(k.core.PrivateKey(), partial[:]) - bs := append([]byte{typeKeyLookup}, sig...) - //_ = k.core.SendOutOfBand(partial, bs) - _ = bs - */ k.core.SendLookup(partial) } -/* -func (k *keyStore) sendKeyResponse(dest ed25519.PublicKey) { // nolint:unused - sig := ed25519.Sign(k.core.PrivateKey(), dest[:]) - bs := append([]byte{typeKeyResponse}, sig...) - //_ = k.core.SendOutOfBand(dest, bs) - _ = bs -} -*/ - func (k *keyStore) readPC(p []byte) (int, error) { buf := make([]byte, k.core.MTU(), 65535) for { @@ -240,16 +187,22 @@ func (k *keyStore) readPC(p []byte) (int, error) { if len(bs) == 0 { continue } - if bs[0]&0xf0 != 0x60 { - continue // not IPv6 - } - if len(bs) < 40 { + ip4 := bs[0]&0xf0 == 0x40 + ip6 := bs[0]&0xf0 == 0x60 + switch { + case !ip4 && !ip6: + continue + case ip6 && len(bs) < 40: + continue + case ip4 && len(bs) < 20: continue } k.mutex.Lock() mtu := int(k.mtu) + th := k.tunnelHelper k.mutex.Unlock() - if len(bs) > mtu { + switch { + case ip6 && len(bs) > mtu: // Using bs would make it leak off the stack, so copy to buf buf := make([]byte, 512) cn := copy(buf, bs) @@ -261,51 +214,87 @@ func (k *keyStore) readPC(p []byte) (int, error) { _, _ = k.writePC(packet) } continue + case len(bs) > mtu: + continue } var srcAddr, dstAddr address.Address var srcSubnet, dstSubnet address.Subnet - copy(srcAddr[:], bs[8:]) - copy(dstAddr[:], bs[24:]) - copy(srcSubnet[:], bs[8:]) - copy(dstSubnet[:], bs[24:]) - if dstAddr != k.address && dstSubnet != k.subnet { - continue // bad local address/subnet + var addrlen int + switch { + case ip4: + copy(srcAddr[:], bs[12:16]) + addrlen = 4 + case ip6: + copy(srcAddr[:], bs[8:]) + copy(srcSubnet[:], bs[8:]) + copy(dstAddr[:], bs[24:]) + copy(dstSubnet[:], bs[24:]) + addrlen = 16 } - info := k.update(ed25519.PublicKey(from.(iwt.Addr))) - if srcAddr != info.address && srcSubnet != info.subnet { - continue // bad remote address/subnet + srcKey := ed25519.PublicKey(from.(iwt.Addr)) + info := k.update(srcKey) + switch { + case ip6 && (srcAddr == info.address || srcSubnet == info.subnet): + return copy(p, bs), nil + case ip4, ip6: + if th == nil { + continue + } + addr, ok := netip.AddrFromSlice(srcAddr[:addrlen]) + if !ok || !th.InboundAllowed(addr, srcKey) { + continue + } } - n = copy(p, bs) - return n, nil + return copy(p, bs), nil } } func (k *keyStore) writePC(bs []byte) (int, error) { - if bs[0]&0xf0 != 0x60 { - return 0, errors.New("not an IPv6 packet") // not IPv6 + if len(bs) == 0 { + return 0, nil } - if len(bs) < 40 { - strErr := fmt.Sprint("undersized IPv6 packet, length: ", len(bs)) - return 0, errors.New(strErr) + ip4 := bs[0]&0xf0 == 0x40 + ip6 := bs[0]&0xf0 == 0x60 + switch { + case !ip4 && !ip6: + return len(bs), nil + case ip6 && len(bs) < 40: + return len(bs), nil + case ip4 && len(bs) < 20: + return len(bs), nil } - var srcAddr, dstAddr address.Address - var srcSubnet, dstSubnet address.Subnet - copy(srcAddr[:], bs[8:]) - copy(dstAddr[:], bs[24:]) - copy(srcSubnet[:], bs[8:]) - copy(dstSubnet[:], bs[24:]) - if srcAddr != k.address && srcSubnet != k.subnet { - // This happens all the time due to link-local traffic - // Don't send back an error, just drop it - strErr := fmt.Sprint("incorrect source address: ", net.IP(srcAddr[:]).String()) - return 0, errors.New(strErr) + var dstAddr address.Address + var dstSubnet address.Subnet + var addrlen int + switch { + case ip4: + copy(dstAddr[:], bs[16:20]) + addrlen = 4 + case ip6: + copy(dstAddr[:], bs[24:40]) + copy(dstSubnet[:], bs[24:40]) + addrlen = 16 } - if dstAddr.IsValid() { + switch { + case dstAddr.IsValid(): k.sendToAddress(dstAddr, bs) - } else if dstSubnet.IsValid() { + case dstSubnet.IsValid(): k.sendToSubnet(dstSubnet, bs) - } else { - return 0, errors.New("invalid destination address") + default: + k.mutex.Lock() + th := k.tunnelHelper + k.mutex.Unlock() + if th == nil { + return len(bs), nil + } + addr, ok := netip.AddrFromSlice(dstAddr[:addrlen]) + if !ok { + return len(bs), nil + } + if key := th.OutboundAllowed(addr); key != nil && len(key) == ed25519.PublicKeySize { + return k.core.WriteTo(bs, iwt.Addr(key)) + } + return len(bs), nil } return len(bs), nil } @@ -366,3 +355,14 @@ func (rwc *ReadWriteCloser) Close() error { rwc.core.Stop() return err } + +func (rwc *ReadWriteCloser) SetTunnelHelper(h TunnelHelper) { + rwc.mutex.Lock() + defer rwc.mutex.Unlock() + rwc.tunnelHelper = h +} + +type TunnelHelper interface { + InboundAllowed(srcip netip.Addr, src ed25519.PublicKey) bool + OutboundAllowed(dstip netip.Addr) ed25519.PublicKey +}