From 14abc820333377e88820e2c0b048f788b3e33dbe Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 26 Feb 2020 11:34:01 -0800 Subject: [PATCH] stun: check high bits in Is, add tests Also use new stun.TxID type in stunner. Signed-off-by: Brad Fitzpatrick --- stun/stun.go | 8 +++----- stun/stun_test.go | 26 ++++++++++++++++++++++++++ stunner/stunner.go | 6 +++--- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/stun/stun.go b/stun/stun.go index ea90633f1..7c0e81992 100644 --- a/stun/stun.go +++ b/stun/stun.go @@ -218,9 +218,7 @@ func mappedAddress(b []byte) (addr []byte, port uint16, err error) { // Is reports whether b is a STUN message. func Is(b []byte) bool { - if len(b) < headerLen { - return false // every STUN message must have a 20-byte header - } - // TODO RFC5389 suggests checking the first 2 bits of the header are zero. - return string(b[4:8]) == magicCookie + return len(b) >= headerLen && + b[0]&0b11000000 == 0 && // top two bits must be zero + string(b[4:8]) == magicCookie } diff --git a/stun/stun_test.go b/stun/stun_test.go index 8088d6265..513c05728 100644 --- a/stun/stun_test.go +++ b/stun/stun_test.go @@ -166,3 +166,29 @@ func TestParseResponse(t *testing.T) { }) } } + +func TestIs(t *testing.T) { + const magicCookie = "\x21\x12\xa4\x42" + tests := []struct { + in string + want bool + }{ + {"", false}, + {"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false}, + {"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false}, + {"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true}, + {"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00foo", true}, + // high bits set: + {"\xf0\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false}, + {"\x40\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false}, + // first byte non-zero, but not high bits: + {"\x20\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true}, + } + for i, tt := range tests { + pkt := []byte(tt.in) + got := stun.Is(pkt) + if got != tt.want { + t.Errorf("%d. In(%q (%v)) = %v; want %v", i, pkt, pkt, got, tt.want) + } + } +} diff --git a/stunner/stunner.go b/stunner/stunner.go index 92ba6c354..6c23af38e 100644 --- a/stunner/stunner.go +++ b/stunner/stunner.go @@ -40,7 +40,7 @@ type Stunner struct { type session struct { replied chan struct{} // closed when server responds - tIDs [][12]byte // transaction IDs sent to a server + tIDs []stun.TxID // transaction IDs sent to a server } // Receive delivers a STUN packet to the stunner. @@ -90,7 +90,7 @@ func (s *Stunner) Run(ctx context.Context) error { } for _, server := range s.Servers { // Generate the transaction IDs for this session. - tIDs := make([][12]byte, len(retryDurations)) + tIDs := make([]stun.TxID, len(retryDurations)) for i := range tIDs { if _, err := rand.Read(tIDs[i][:]); err != nil { return fmt.Errorf("stunner: rand failed: %v", err) @@ -147,7 +147,7 @@ func (s *Stunner) runServer(ctx context.Context, server string) { } } -func (s *Stunner) sendSTUN(ctx context.Context, tID [12]byte, server string) error { +func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) error { host, port, err := net.SplitHostPort(server) if err != nil { return err