mirror of
https://github.com/restic/restic.git
synced 2025-08-22 17:41:04 +00:00
backend/limiter: replace juju/ratelimit with x/time/rate
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/juju/ratelimit"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type staticLimiter struct {
|
||||
upstream *ratelimit.Bucket
|
||||
downstream *ratelimit.Bucket
|
||||
upstream *rate.Limiter
|
||||
downstream *rate.Limiter
|
||||
}
|
||||
|
||||
// Limits represents static upload and download limits.
|
||||
@@ -23,16 +24,16 @@ type Limits struct {
|
||||
// download rate cap
|
||||
func NewStaticLimiter(l Limits) Limiter {
|
||||
var (
|
||||
upstreamBucket *ratelimit.Bucket
|
||||
downstreamBucket *ratelimit.Bucket
|
||||
upstreamBucket *rate.Limiter
|
||||
downstreamBucket *rate.Limiter
|
||||
)
|
||||
|
||||
if l.UploadKb > 0 {
|
||||
upstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.UploadKb), int64(toByteRate(l.UploadKb)))
|
||||
upstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.UploadKb)), int(toByteRate(l.UploadKb)))
|
||||
}
|
||||
|
||||
if l.DownloadKb > 0 {
|
||||
downstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.DownloadKb), int64(toByteRate(l.DownloadKb)))
|
||||
downstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.DownloadKb)), int(toByteRate(l.DownloadKb)))
|
||||
}
|
||||
|
||||
return staticLimiter{
|
||||
@@ -95,18 +96,55 @@ func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper {
|
||||
})
|
||||
}
|
||||
|
||||
func (l staticLimiter) limitReader(r io.Reader, b *ratelimit.Bucket) io.Reader {
|
||||
func (l staticLimiter) limitReader(r io.Reader, b *rate.Limiter) io.Reader {
|
||||
if b == nil {
|
||||
return r
|
||||
}
|
||||
return ratelimit.Reader(r, b)
|
||||
return &rateLimitedReader{r, b}
|
||||
}
|
||||
|
||||
func (l staticLimiter) limitWriter(w io.Writer, b *ratelimit.Bucket) io.Writer {
|
||||
type rateLimitedReader struct {
|
||||
reader io.Reader
|
||||
bucket *rate.Limiter
|
||||
}
|
||||
|
||||
func (r *rateLimitedReader) Read(p []byte) (int, error) {
|
||||
n, err := r.reader.Read(p)
|
||||
if err := consumeTokens(n, r.bucket); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (l staticLimiter) limitWriter(w io.Writer, b *rate.Limiter) io.Writer {
|
||||
if b == nil {
|
||||
return w
|
||||
}
|
||||
return ratelimit.Writer(w, b)
|
||||
return &rateLimitedWriter{w, b}
|
||||
}
|
||||
|
||||
type rateLimitedWriter struct {
|
||||
writer io.Writer
|
||||
bucket *rate.Limiter
|
||||
}
|
||||
|
||||
func (w *rateLimitedWriter) Write(buf []byte) (int, error) {
|
||||
if err := consumeTokens(len(buf), w.bucket); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return w.writer.Write(buf)
|
||||
}
|
||||
|
||||
func consumeTokens(tokens int, bucket *rate.Limiter) error {
|
||||
// bucket allows waiting for at most Burst() tokens at once
|
||||
maxWait := bucket.Burst()
|
||||
for tokens > maxWait {
|
||||
if err := bucket.WaitN(context.Background(), maxWait); err != nil {
|
||||
return err
|
||||
}
|
||||
tokens -= maxWait
|
||||
}
|
||||
return bucket.WaitN(context.Background(), tokens)
|
||||
}
|
||||
|
||||
func toByteRate(val int) float64 {
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/restic/restic/internal/test"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func TestLimiterWrapping(t *testing.T) {
|
||||
@@ -33,6 +34,38 @@ func TestLimiterWrapping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLimiter(t *testing.T) {
|
||||
reader := bytes.NewReader(make([]byte, 300))
|
||||
limiter := rate.NewLimiter(rate.Limit(10000), int(100))
|
||||
limReader := rateLimitedReader{reader, limiter}
|
||||
|
||||
n, err := limReader.Read([]byte{})
|
||||
test.OK(t, err)
|
||||
test.Equals(t, n, 0)
|
||||
|
||||
n, err = limReader.Read(make([]byte, 300))
|
||||
test.OK(t, err)
|
||||
test.Equals(t, n, 300)
|
||||
|
||||
n, err = limReader.Read([]byte{})
|
||||
test.Equals(t, err, io.EOF)
|
||||
test.Equals(t, n, 0)
|
||||
}
|
||||
|
||||
func TestWriteLimiter(t *testing.T) {
|
||||
writer := &bytes.Buffer{}
|
||||
limiter := rate.NewLimiter(rate.Limit(10000), int(100))
|
||||
limReader := rateLimitedWriter{writer, limiter}
|
||||
|
||||
n, err := limReader.Write([]byte{})
|
||||
test.OK(t, err)
|
||||
test.Equals(t, n, 0)
|
||||
|
||||
n, err = limReader.Write(make([]byte, 300))
|
||||
test.OK(t, err)
|
||||
test.Equals(t, n, 300)
|
||||
}
|
||||
|
||||
type tracedReadCloser struct {
|
||||
io.Reader
|
||||
Closed bool
|
||||
|
Reference in New Issue
Block a user