diff --git a/derp/derp_client.go b/derp/derp_client.go index c584fd554..75a148298 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -16,6 +16,7 @@ import ( "time" "golang.org/x/crypto/nacl/box" + "golang.org/x/time/rate" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -32,8 +33,9 @@ type Client struct { canAckPings bool isProber bool - wmu sync.Mutex // hold while writing to bw - bw *bufio.Writer + wmu sync.Mutex // hold while writing to bw + bw *bufio.Writer + rate *rate.Limiter // if non-nil, rate limiter to use // Owned by Recv: peeked int // bytes to discard on next Recv @@ -217,7 +219,12 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) { c.wmu.Lock() defer c.wmu.Unlock() - + if c.rate != nil { + pktLen := frameHeaderLen + len(dstKey) + len(pkt) + if !c.rate.AllowN(time.Now(), pktLen) { + return nil // drop + } + } if err := writeFrameHeader(c.bw, frameSendPacket, uint32(len(dstKey)+len(pkt))); err != nil { return err } @@ -353,7 +360,22 @@ type PeerPresentMessage key.Public func (PeerPresentMessage) msg() {} // ServerInfoMessage is sent by the server upon first connect. -type ServerInfoMessage struct{} +type ServerInfoMessage struct { + // TokenBucketBytesPerSecond is how many bytes per second the + // server says it will accept, including all framing bytes. + // + // Zero means unspecified. There might be a limit, but the + // client need not try to respect it. + TokenBucketBytesPerSecond int + + // TokenBucketBytesBurst is how many bytes the server will + // allow to burst, temporarily violating + // TokenBucketBytesPerSecond. + // + // Zero means unspecified. There might be a limit, but the + // client need not try to respect it. + TokenBucketBytesBurst int +} func (ServerInfoMessage) msg() {} @@ -475,12 +497,16 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro // needing to wait an RTT to discover the version at startup. // We'd prefer to give the connection to the client (magicsock) // to start writing as soon as possible. - _, err := c.parseServerInfo(b) + si, err := c.parseServerInfo(b) if err != nil { return nil, fmt.Errorf("invalid server info frame: %v", err) } - // TODO: add the results of parseServerInfo to ServerInfoMessage if we ever need it. - return ServerInfoMessage{}, nil + sm := ServerInfoMessage{ + TokenBucketBytesPerSecond: si.TokenBucketBytesPerSecond, + TokenBucketBytesBurst: si.TokenBucketBytesBurst, + } + c.setSendRateLimiter(sm) + return sm, nil case frameKeepAlive: // A one-way keep-alive message that doesn't require an acknowledgement. // This predated framePing/framePong. @@ -537,3 +563,16 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro } } } + +func (c *Client) setSendRateLimiter(sm ServerInfoMessage) { + c.wmu.Lock() + defer c.wmu.Unlock() + + if sm.TokenBucketBytesPerSecond == 0 { + c.rate = nil + } else { + c.rate = rate.NewLimiter( + rate.Limit(sm.TokenBucketBytesPerSecond), + sm.TokenBucketBytesBurst) + } +} diff --git a/derp/derp_server.go b/derp/derp_server.go index c45db299f..cccf4eacd 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -1079,6 +1079,9 @@ func (s *Server) noteClientActivity(c *sclient) { type serverInfo struct { Version int `json:"version,omitempty"` + + TokenBucketBytesPerSecond int `json:",omitempty"` + TokenBucketBytesBurst int `json:",omitempty"` } func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.Public) error { diff --git a/derp/derp_test.go b/derp/derp_test.go index 7219d762e..7434d9b86 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -1244,3 +1244,80 @@ func TestParseSSOutput(t *testing.T) { t.Errorf("parseSSOutput expected non-empty map") } } + +type countWriter struct { + mu sync.Mutex + writes int + bytes int64 +} + +func (w *countWriter) Write(p []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + w.writes++ + w.bytes += int64(len(p)) + return len(p), nil +} + +func (w *countWriter) Stats() (writes int, bytes int64) { + w.mu.Lock() + defer w.mu.Unlock() + return w.writes, w.bytes +} + +func (w *countWriter) ResetStats() { + w.mu.Lock() + defer w.mu.Unlock() + w.writes, w.bytes = 0, 0 +} + +func TestClientSendRateLimiting(t *testing.T) { + cw := new(countWriter) + c := &Client{ + bw: bufio.NewWriter(cw), + } + c.setSendRateLimiter(ServerInfoMessage{}) + + pkt := make([]byte, 1000) + if err := c.send(key.Public{}, pkt); err != nil { + t.Fatal(err) + } + writes1, bytes1 := cw.Stats() + if writes1 != 1 { + t.Errorf("writes = %v, want 1", writes1) + } + + // Flood should all succeed. + cw.ResetStats() + for i := 0; i < 1000; i++ { + if err := c.send(key.Public{}, pkt); err != nil { + t.Fatal(err) + } + } + writes1K, bytes1K := cw.Stats() + if writes1K != 1000 { + t.Logf("writes = %v; want 1000", writes1K) + } + if got, want := bytes1K, bytes1*1000; got != want { + t.Logf("bytes = %v; want %v", got, want) + } + + // Set a rate limiter + cw.ResetStats() + c.setSendRateLimiter(ServerInfoMessage{ + TokenBucketBytesPerSecond: 1, + TokenBucketBytesBurst: int(bytes1 * 2), + }) + for i := 0; i < 1000; i++ { + if err := c.send(key.Public{}, pkt); err != nil { + t.Fatal(err) + } + } + writesLimited, bytesLimited := cw.Stats() + if writesLimited == 0 || writesLimited == writes1K { + t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited) + } + if bytesLimited < bytes1*2 || bytesLimited >= bytes1K { + t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K) + } +}