mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 21:15:39 +00:00
151 lines
3.2 KiB
Go
151 lines
3.2 KiB
Go
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package stunner
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sort"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"gortc.io/stun"
|
||
|
)
|
||
|
|
||
|
func TestStun(t *testing.T) {
|
||
|
conn1, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
defer conn1.Close()
|
||
|
conn2, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
defer conn2.Close()
|
||
|
stunServers := []string{
|
||
|
conn1.LocalAddr().String(), conn2.LocalAddr().String(),
|
||
|
}
|
||
|
|
||
|
epCh := make(chan string, 16)
|
||
|
|
||
|
localConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
s := &Stunner{
|
||
|
Send: localConn.WriteTo,
|
||
|
Endpoint: func(ep string) { epCh <- ep },
|
||
|
Servers: stunServers,
|
||
|
}
|
||
|
|
||
|
stun1Err := make(chan error)
|
||
|
go func() {
|
||
|
stun1Err <- startSTUN(conn1, s.Receive)
|
||
|
}()
|
||
|
stun2Err := make(chan error)
|
||
|
go func() {
|
||
|
stun2Err <- startSTUNDrop1(conn2, s.Receive)
|
||
|
}()
|
||
|
|
||
|
errCh := make(chan error)
|
||
|
go func() {
|
||
|
errCh <- s.Run(context.Background())
|
||
|
}()
|
||
|
|
||
|
var eps []string
|
||
|
select {
|
||
|
case ep := <-epCh:
|
||
|
eps = append(eps, ep)
|
||
|
case <-time.After(100 * time.Millisecond):
|
||
|
t.Fatal("missing first endpoint response")
|
||
|
}
|
||
|
select {
|
||
|
case ep := <-epCh:
|
||
|
eps = append(eps, ep)
|
||
|
case <-time.After(500 * time.Millisecond):
|
||
|
t.Fatal("missing second endpoint response")
|
||
|
}
|
||
|
sort.Strings(eps)
|
||
|
if want := "1.2.3.4:1234"; eps[0] != want {
|
||
|
t.Errorf("eps[0]=%q, want %q", eps[0], want)
|
||
|
}
|
||
|
if want := "4.5.6.7:4567"; eps[1] != want {
|
||
|
t.Errorf("eps[1]=%q, want %q", eps[1], want)
|
||
|
}
|
||
|
|
||
|
if err := <-errCh; err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func startSTUNDrop1(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error {
|
||
|
if _, _, err := conn.ReadFrom(make([]byte, 1024)); err != nil {
|
||
|
return fmt.Errorf("first stun server read failed: %v", err)
|
||
|
}
|
||
|
req := new(stun.Message)
|
||
|
res := new(stun.Message)
|
||
|
|
||
|
p := make([]byte, 1024)
|
||
|
n, addr, err := conn.ReadFrom(p)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
p = p[:n]
|
||
|
if !stun.IsMessage(p) {
|
||
|
return errors.New("not a STUN message")
|
||
|
}
|
||
|
if _, err := req.Write(p); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
mappedAddr := &stun.XORMappedAddress{
|
||
|
IP: net.ParseIP("1.2.3.4"),
|
||
|
Port: 1234,
|
||
|
}
|
||
|
software := stun.NewSoftware("endpointer")
|
||
|
err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
writeTo(res.Raw, addr.(*net.UDPAddr))
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func startSTUN(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error {
|
||
|
req := new(stun.Message)
|
||
|
res := new(stun.Message)
|
||
|
|
||
|
p := make([]byte, 1024)
|
||
|
n, addr, err := conn.ReadFrom(p)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
p = p[:n]
|
||
|
if !stun.IsMessage(p) {
|
||
|
return errors.New("not a STUN message")
|
||
|
}
|
||
|
if _, err := req.Write(p); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
mappedAddr := &stun.XORMappedAddress{
|
||
|
IP: net.ParseIP("4.5.6.7"),
|
||
|
Port: 4567,
|
||
|
}
|
||
|
software := stun.NewSoftware("endpointer")
|
||
|
err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
writeTo(res.Raw, addr.(*net.UDPAddr))
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// TODO: test retry timeout (override retryDurations)
|
||
|
// TODO: test canceling context passed to Run
|
||
|
// TODO: test sending bad packets
|