cmd/lopower: get e2e packets working

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-11-02 15:34:12 -07:00 committed by Anton Tolchanov
parent 922d65ed11
commit 44795dea4b

View File

@ -160,10 +160,12 @@ func newLP(ctx context.Context) *lpServer {
logf := log.Printf logf := log.Printf
deviceLogger := &device.Logger{ deviceLogger := &device.Logger{
Verbosef: logger.Discard, Verbosef: logger.Discard,
// Verbosef: logf,
Errorf: logf, Errorf: logf,
} }
lp := &lpServer{ lp := &lpServer{
dir: *confDir, dir: *confDir,
readCh: make(chan *stack.PacketBuffer, 16),
} }
lp.loadConfig() lp.loadConfig()
lp.initNetstack(ctx) lp.initNetstack(ctx)
@ -174,7 +176,6 @@ func newLP(ctx context.Context) *lpServer {
} }
wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger) wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger)
defer wgdev.Close()
lp.d = wgdev lp.d = wgdev
must.Do(wgdev.Up()) must.Do(wgdev.Up())
lp.reconfig() lp.reconfig()
@ -193,6 +194,7 @@ type lpServer struct {
d *device.Device d *device.Device
ns *stack.Stack ns *stack.Stack
linkEP *channel.Endpoint linkEP *channel.Endpoint
readCh chan *stack.PacketBuffer
mu sync.Mutex // protects following mu sync.Mutex // protects following
c *config c *config
@ -203,6 +205,7 @@ type lpServer struct {
const MaxPacketSize = device.MaxContentSize const MaxPacketSize = device.MaxContentSize
func (lp *lpServer) initNetstack(ctx context.Context) error { func (lp *lpServer) initNetstack(ctx context.Context) error {
ns := stack.New(stack.Options{ ns := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol, ipv4.NewProtocol,
@ -284,6 +287,7 @@ func (lp *lpServer) initNetstack(ctx context.Context) error {
if pkt == nil { if pkt == nil {
if ctx.Err() != nil { if ctx.Err() != nil {
// Return without logging. // Return without logging.
log.Printf("linkEP.ReadContext: %v", ctx.Err())
return return
} }
continue continue
@ -293,6 +297,10 @@ func (lp *lpServer) initNetstack(ctx context.Context) error {
pkt.DecRef() pkt.DecRef()
continue continue
} }
select {
case lp.readCh <- pkt:
case <-ctx.Done():
}
} }
}() }()
return nil return nil
@ -309,31 +317,39 @@ func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr {
} }
func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) { func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) {
log.Printf("acceptTCP: %v", r.ID())
var wq waiter.Queue var wq waiter.Queue
ep, tcpErr := r.CreateEndpoint(&wq) ep, tcpErr := r.CreateEndpoint(&wq)
if tcpErr != nil { if tcpErr != nil {
log.Printf("CreateEndpoint: %v", tcpErr)
r.Complete(true) r.Complete(true)
return return
} }
log.Printf("created endpoint %v", ep)
defer ep.Close() defer ep.Close()
ep.SocketOptions().SetKeepAlive(true)
reqDetails := r.ID() reqDetails := r.ID()
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
destPort := reqDetails.LocalPort destPort := reqDetails.LocalPort
if !clientRemoteIP.IsValid() { if !clientRemoteIP.IsValid() {
log.Printf("acceptTCP: invalid remote IP %v", reqDetails.RemoteAddress)
r.Complete(true) // sends a RST r.Complete(true) // sends a RST
return return
} }
log.Printf("request from %v to %v:%d", clientRemoteIP, destIP, destPort)
dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) dialCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort)) c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort))
cancel() cancel()
if err != nil { if err != nil {
log.Printf("Dial(%s:%d): %v", destIP, destPort, err)
r.Complete(true) // sends a RST r.Complete(true) // sends a RST
return return
} }
defer c.Close() defer c.Close()
log.Printf("Connected to %s:%d", destIP, destPort)
tc := gonet.NewTCPConn(&wq, ep) tc := gonet.NewTCPConn(&wq, ep)
defer tc.Close() defer tc.Close()
@ -341,13 +357,17 @@ func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) {
errc := make(chan error, 2) errc := make(chan error, 2)
go func() { _, err := io.Copy(tc, c); errc <- err }() go func() { _, err := io.Copy(tc, c); errc <- err }()
go func() { _, err := io.Copy(c, tc); errc <- err }() go func() { _, err := io.Copy(c, tc); errc <- err }()
<-errc err = <-errc
if err != nil {
log.Printf("io.Copy: %v", err)
}
} }
func (lp *lpServer) wgConfigForQR() string { func (lp *lpServer) wgConfigForQR() string {
var b strings.Builder var b strings.Builder
privHex, _ := lp.c.Peers[0].PrivKey.MarshalText() p := lp.c.Peers[0]
privHex, _ := p.PrivKey.MarshalText()
privHex = bytes.TrimPrefix(privHex, []byte("privkey:")) privHex = bytes.TrimPrefix(privHex, []byte("privkey:"))
priv := make([]byte, 32) priv := make([]byte, 32)
got, err := hex.Decode(priv, privHex) got, err := hex.Decode(priv, privHex)
@ -358,7 +378,7 @@ func (lp *lpServer) wgConfigForQR() string {
privb64 := base64.StdEncoding.EncodeToString(priv) privb64 := base64.StdEncoding.EncodeToString(priv)
fmt.Fprintf(&b, "[Interface]\nPrivateKey = %s\n", privb64) fmt.Fprintf(&b, "[Interface]\nPrivateKey = %s\n", privb64)
fmt.Fprintf(&b, "Address = %v\n", lp.c.V6) fmt.Fprintf(&b, "Address = %v,%v\n", p.V6, p.V4)
pubBin, _ := lp.c.PrivKey.Public().MarshalBinary() pubBin, _ := lp.c.PrivKey.Public().MarshalBinary()
if len(pubBin) != 34 { if len(pubBin) != 34 {
@ -368,7 +388,7 @@ func (lp *lpServer) wgConfigForQR() string {
pubb64 := base64.StdEncoding.EncodeToString(pubBin) pubb64 := base64.StdEncoding.EncodeToString(pubBin)
fmt.Fprintf(&b, "[Peer]\nPublicKey = %v\n", pubb64) fmt.Fprintf(&b, "[Peer]\nPublicKey = %v\n", pubb64)
fmt.Fprintf(&b, "AllowedIPs = %v\n", tsaddr.TailscaleULARange()) fmt.Fprintf(&b, "AllowedIPs = %v/32,%v/128,%v,%v\n", lp.c.V4, lp.c.V6, tsaddr.TailscaleULARange(), tsaddr.CGNATRange())
fmt.Fprintf(&b, "Endpoint = %v\n", net.JoinHostPort(*wgPubHost, fmt.Sprint(*wgListenPort))) fmt.Fprintf(&b, "Endpoint = %v\n", net.JoinHostPort(*wgPubHost, fmt.Sprint(*wgListenPort)))
return b.String() return b.String()
@ -404,7 +424,6 @@ func (lp *lpServer) serveQR() {
type nsTUN struct { type nsTUN struct {
lp *lpServer lp *lpServer
closeCh chan struct{} closeCh chan struct{}
readCh chan *stack.PacketBuffer
evChan chan tun.Event evChan chan tun.Event
} }
@ -422,7 +441,7 @@ func (t *nsTUN) Read(out [][]byte, sizes []int, offset int) (int, error) {
select { select {
case <-t.closeCh: case <-t.closeCh:
return 0, io.EOF return 0, io.EOF
case resPacket := <-t.readCh: case resPacket := <-t.lp.readCh:
defer resPacket.DecRef() defer resPacket.DecRef()
pkt := out[0][offset:] pkt := out[0][offset:]
n := copy(pkt, resPacket.NetworkHeader().Slice()) n := copy(pkt, resPacket.NetworkHeader().Slice())