backend/limiter: replace juju/ratelimit with x/time/rate

This commit is contained in:
Michael Eischer
2023-10-01 22:45:18 +02:00
parent c635e30e3f
commit f750aa8dfb
4 changed files with 85 additions and 14 deletions

View File

@@ -1,15 +1,16 @@
package limiter
import (
"context"
"io"
"net/http"
"github.com/juju/ratelimit"
"golang.org/x/time/rate"
)
type staticLimiter struct {
upstream *ratelimit.Bucket
downstream *ratelimit.Bucket
upstream *rate.Limiter
downstream *rate.Limiter
}
// Limits represents static upload and download limits.
@@ -23,16 +24,16 @@ type Limits struct {
// download rate cap
func NewStaticLimiter(l Limits) Limiter {
var (
upstreamBucket *ratelimit.Bucket
downstreamBucket *ratelimit.Bucket
upstreamBucket *rate.Limiter
downstreamBucket *rate.Limiter
)
if l.UploadKb > 0 {
upstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.UploadKb), int64(toByteRate(l.UploadKb)))
upstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.UploadKb)), int(toByteRate(l.UploadKb)))
}
if l.DownloadKb > 0 {
downstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.DownloadKb), int64(toByteRate(l.DownloadKb)))
downstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.DownloadKb)), int(toByteRate(l.DownloadKb)))
}
return staticLimiter{
@@ -95,18 +96,55 @@ func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper {
})
}
func (l staticLimiter) limitReader(r io.Reader, b *ratelimit.Bucket) io.Reader {
func (l staticLimiter) limitReader(r io.Reader, b *rate.Limiter) io.Reader {
if b == nil {
return r
}
return ratelimit.Reader(r, b)
return &rateLimitedReader{r, b}
}
func (l staticLimiter) limitWriter(w io.Writer, b *ratelimit.Bucket) io.Writer {
type rateLimitedReader struct {
reader io.Reader
bucket *rate.Limiter
}
func (r *rateLimitedReader) Read(p []byte) (int, error) {
n, err := r.reader.Read(p)
if err := consumeTokens(n, r.bucket); err != nil {
return n, err
}
return n, err
}
func (l staticLimiter) limitWriter(w io.Writer, b *rate.Limiter) io.Writer {
if b == nil {
return w
}
return ratelimit.Writer(w, b)
return &rateLimitedWriter{w, b}
}
type rateLimitedWriter struct {
writer io.Writer
bucket *rate.Limiter
}
func (w *rateLimitedWriter) Write(buf []byte) (int, error) {
if err := consumeTokens(len(buf), w.bucket); err != nil {
return 0, err
}
return w.writer.Write(buf)
}
func consumeTokens(tokens int, bucket *rate.Limiter) error {
// bucket allows waiting for at most Burst() tokens at once
maxWait := bucket.Burst()
for tokens > maxWait {
if err := bucket.WaitN(context.Background(), maxWait); err != nil {
return err
}
tokens -= maxWait
}
return bucket.WaitN(context.Background(), tokens)
}
func toByteRate(val int) float64 {

View File

@@ -9,6 +9,7 @@ import (
"testing"
"github.com/restic/restic/internal/test"
"golang.org/x/time/rate"
)
func TestLimiterWrapping(t *testing.T) {
@@ -33,6 +34,38 @@ func TestLimiterWrapping(t *testing.T) {
}
}
func TestReadLimiter(t *testing.T) {
reader := bytes.NewReader(make([]byte, 300))
limiter := rate.NewLimiter(rate.Limit(10000), int(100))
limReader := rateLimitedReader{reader, limiter}
n, err := limReader.Read([]byte{})
test.OK(t, err)
test.Equals(t, n, 0)
n, err = limReader.Read(make([]byte, 300))
test.OK(t, err)
test.Equals(t, n, 300)
n, err = limReader.Read([]byte{})
test.Equals(t, err, io.EOF)
test.Equals(t, n, 0)
}
func TestWriteLimiter(t *testing.T) {
writer := &bytes.Buffer{}
limiter := rate.NewLimiter(rate.Limit(10000), int(100))
limReader := rateLimitedWriter{writer, limiter}
n, err := limReader.Write([]byte{})
test.OK(t, err)
test.Equals(t, n, 0)
n, err = limReader.Write(make([]byte, 300))
test.OK(t, err)
test.Equals(t, n, 300)
}
type tracedReadCloser struct {
io.Reader
Closed bool