mirror of
https://github.com/restic/restic.git
synced 2025-08-23 15:28:08 +00:00
azure: Switch to azblob sdk
This commit is contained in:

committed by
Michael Eischer

parent
4ba31df08f
commit
25648e2501
@@ -1,6 +1,7 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
@@ -18,14 +19,20 @@ import (
|
||||
"github.com/restic/restic/internal/errors"
|
||||
"github.com/restic/restic/internal/restic"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/storage"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
|
||||
azContainer "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
)
|
||||
|
||||
// Backend stores data on an azure endpoint.
|
||||
type Backend struct {
|
||||
accountName string
|
||||
container *storage.Container
|
||||
cfg Config
|
||||
container *azContainer.Client
|
||||
connections uint
|
||||
sem sema.Semaphore
|
||||
prefix string
|
||||
@@ -33,6 +40,7 @@ type Backend struct {
|
||||
layout.Layout
|
||||
}
|
||||
|
||||
const saveLargeSize = 256 * 1024 * 1024
|
||||
const defaultListMaxItems = 5000
|
||||
|
||||
// make sure that *Backend implements backend.Backend
|
||||
@@ -40,29 +48,47 @@ var _ restic.Backend = &Backend{}
|
||||
|
||||
func open(cfg Config, rt http.RoundTripper) (*Backend, error) {
|
||||
debug.Log("open, config %#v", cfg)
|
||||
var client storage.Client
|
||||
var client *azContainer.Client
|
||||
var err error
|
||||
|
||||
url := fmt.Sprintf("https://%s.blob.core.windows.net/%s", cfg.AccountName, cfg.Container)
|
||||
opts := &azContainer.ClientOptions{
|
||||
ClientOptions: azcore.ClientOptions{
|
||||
Transport: http.DefaultClient,
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.AccountKey.String() != "" {
|
||||
// We have an account key value, find the BlobServiceClient
|
||||
// from with a BasicClient
|
||||
debug.Log(" - using account key")
|
||||
client, err = storage.NewBasicClient(cfg.AccountName, cfg.AccountKey.Unwrap())
|
||||
cred, err := azblob.NewSharedKeyCredential(cfg.AccountName, cfg.AccountKey.Unwrap())
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "NewBasicClient")
|
||||
return nil, errors.Wrap(err, "NewSharedKeyCredential")
|
||||
}
|
||||
|
||||
client, err = azContainer.NewClientWithSharedKeyCredential(url, cred, opts)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "NewClientWithSharedKeyCredential")
|
||||
}
|
||||
} else if cfg.AccountSAS.String() != "" {
|
||||
// Get the client using the SAS Token as authentication, this
|
||||
// is longer winded than above because the SDK wants a URL for the Account
|
||||
// if your using a SAS token, and not just the account name
|
||||
// we (as per the SDK ) assume the default Azure portal.
|
||||
url := fmt.Sprintf("https://%s.blob.core.windows.net/", cfg.AccountName)
|
||||
// https://github.com/Azure/azure-storage-blob-go/issues/130
|
||||
debug.Log(" - using sas token")
|
||||
sas := cfg.AccountSAS.Unwrap()
|
||||
|
||||
// strip query sign prefix
|
||||
if sas[0] == '?' {
|
||||
sas = sas[1:]
|
||||
}
|
||||
client, err = storage.NewAccountSASClientFromEndpointToken(url, sas)
|
||||
|
||||
urlWithSAS := fmt.Sprintf("%s?%s", url, sas)
|
||||
|
||||
client, err = azContainer.NewClientWithNoCredential(urlWithSAS, opts)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "NewAccountSASClientFromEndpointToken")
|
||||
}
|
||||
@@ -70,21 +96,16 @@ func open(cfg Config, rt http.RoundTripper) (*Backend, error) {
|
||||
return nil, errors.New("no azure authentication information found")
|
||||
}
|
||||
|
||||
client.HTTPClient = &http.Client{Transport: rt}
|
||||
|
||||
service := client.GetBlobService()
|
||||
|
||||
sem, err := sema.New(cfg.Connections)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
be := &Backend{
|
||||
container: service.GetContainerReference(cfg.Container),
|
||||
accountName: cfg.AccountName,
|
||||
container: client,
|
||||
cfg: cfg,
|
||||
connections: cfg.Connections,
|
||||
sem: sem,
|
||||
prefix: cfg.Prefix,
|
||||
Layout: &layout.DefaultLayout{
|
||||
Path: cfg.Prefix,
|
||||
Join: path.Join,
|
||||
@@ -96,26 +117,27 @@ func open(cfg Config, rt http.RoundTripper) (*Backend, error) {
|
||||
}
|
||||
|
||||
// Open opens the Azure backend at specified container.
|
||||
func Open(cfg Config, rt http.RoundTripper) (*Backend, error) {
|
||||
func Open(ctx context.Context, cfg Config, rt http.RoundTripper) (*Backend, error) {
|
||||
return open(cfg, rt)
|
||||
}
|
||||
|
||||
// Create opens the Azure backend at specified container and creates the container if
|
||||
// it does not exist yet.
|
||||
func Create(cfg Config, rt http.RoundTripper) (*Backend, error) {
|
||||
func Create(ctx context.Context, cfg Config, rt http.RoundTripper) (*Backend, error) {
|
||||
be, err := open(cfg, rt)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "open")
|
||||
}
|
||||
|
||||
options := storage.CreateContainerOptions{
|
||||
Access: storage.ContainerAccessTypePrivate,
|
||||
}
|
||||
if err != nil && bloberror.HasCode(err, bloberror.ContainerNotFound) {
|
||||
_, err = be.container.Create(ctx, &azContainer.CreateOptions{})
|
||||
|
||||
_, err = be.container.CreateIfNotExists(&options)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "container.CreateIfNotExists")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "container.Create")
|
||||
}
|
||||
} else if err != nil {
|
||||
return be, err
|
||||
}
|
||||
|
||||
return be, nil
|
||||
@@ -129,8 +151,7 @@ func (be *Backend) SetListMaxItems(i int) {
|
||||
// IsNotExist returns true if the error is caused by a not existing file.
|
||||
func (be *Backend) IsNotExist(err error) bool {
|
||||
debug.Log("IsNotExist(%T, %#v)", err, err)
|
||||
var aerr storage.AzureStorageServiceError
|
||||
return errors.As(err, &aerr) && aerr.StatusCode == http.StatusNotFound
|
||||
return bloberror.HasCode(err, bloberror.BlobNotFound)
|
||||
}
|
||||
|
||||
// Join combines path components with slashes.
|
||||
@@ -144,7 +165,7 @@ func (be *Backend) Connections() uint {
|
||||
|
||||
// Location returns this backend's location (the container name).
|
||||
func (be *Backend) Location() string {
|
||||
return be.Join(be.container.Name, be.prefix)
|
||||
return be.Join(be.cfg.AccountName, be.cfg.Prefix)
|
||||
}
|
||||
|
||||
// Hasher may return a hash function for calculating a content hash for the backend
|
||||
@@ -162,16 +183,6 @@ func (be *Backend) Path() string {
|
||||
return be.prefix
|
||||
}
|
||||
|
||||
type azureAdapter struct {
|
||||
restic.RewindReader
|
||||
}
|
||||
|
||||
func (azureAdapter) Close() error { return nil }
|
||||
|
||||
func (a azureAdapter) Len() int {
|
||||
return int(a.Length())
|
||||
}
|
||||
|
||||
// Save stores data in the backend at the handle.
|
||||
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
|
||||
if err := h.Valid(); err != nil {
|
||||
@@ -184,41 +195,53 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindRe
|
||||
|
||||
be.sem.GetToken()
|
||||
|
||||
debug.Log("InsertObject(%v, %v)", be.container.Name, objName)
|
||||
debug.Log("InsertObject(%v, %v)", be.cfg.AccountName, objName)
|
||||
|
||||
var err error
|
||||
if rd.Length() < 256*1024*1024 {
|
||||
// wrap the reader so that net/http client cannot close the reader
|
||||
// CreateBlockBlobFromReader reads length from `Len()``
|
||||
dataReader := azureAdapter{rd}
|
||||
|
||||
if rd.Length() < saveLargeSize {
|
||||
// if it's smaller than 256miB, then just create the file directly from the reader
|
||||
ref := be.container.GetBlobReference(objName)
|
||||
ref.Properties.ContentMD5 = base64.StdEncoding.EncodeToString(rd.Hash())
|
||||
err = ref.CreateBlockBlobFromReader(dataReader, nil)
|
||||
err = be.saveSmall(ctx, objName, rd)
|
||||
} else {
|
||||
// otherwise use the more complicated method
|
||||
err = be.saveLarge(ctx, objName, rd)
|
||||
|
||||
}
|
||||
|
||||
be.sem.ReleaseToken()
|
||||
debug.Log("%v, err %#v", objName, err)
|
||||
|
||||
return errors.Wrap(err, "CreateBlockBlobFromReader")
|
||||
return err
|
||||
}
|
||||
|
||||
func (be *Backend) saveSmall(ctx context.Context, objName string, rd restic.RewindReader) error {
|
||||
blockBlobClient := be.container.NewBlockBlobClient(objName)
|
||||
|
||||
// upload it as a new "block", use the base64 hash for the ID
|
||||
id := base64.StdEncoding.EncodeToString(rd.Hash())
|
||||
|
||||
buf := make([]byte, rd.Length())
|
||||
_, err := io.ReadFull(rd, buf)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "ReadFull")
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(buf)
|
||||
_, err = blockBlobClient.StageBlock(ctx, id, streaming.NopCloser(reader), &blockblob.StageBlockOptions{
|
||||
TransactionalContentMD5: rd.Hash(),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "StageBlock")
|
||||
}
|
||||
|
||||
blocks := []string{id}
|
||||
_, err = blockBlobClient.CommitBlockList(ctx, blocks, &blockblob.CommitBlockListOptions{})
|
||||
return errors.Wrap(err, "CommitBlockList")
|
||||
}
|
||||
|
||||
func (be *Backend) saveLarge(ctx context.Context, objName string, rd restic.RewindReader) error {
|
||||
// create the file on the server
|
||||
file := be.container.GetBlobReference(objName)
|
||||
err := file.CreateBlockBlob(nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "CreateBlockBlob")
|
||||
}
|
||||
blockBlobClient := be.container.NewBlockBlobClient(objName)
|
||||
|
||||
// read the data, in 100 MiB chunks
|
||||
buf := make([]byte, 100*1024*1024)
|
||||
var blocks []storage.Block
|
||||
blocks := []string{}
|
||||
uploadedBytes := 0
|
||||
|
||||
for {
|
||||
@@ -226,6 +249,7 @@ func (be *Backend) saveLarge(ctx context.Context, objName string, rd restic.Rewi
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
err = nil
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
// end of file reached, no bytes have been read at all
|
||||
break
|
||||
@@ -241,16 +265,18 @@ func (be *Backend) saveLarge(ctx context.Context, objName string, rd restic.Rewi
|
||||
// upload it as a new "block", use the base64 hash for the ID
|
||||
h := md5.Sum(buf)
|
||||
id := base64.StdEncoding.EncodeToString(h[:])
|
||||
debug.Log("PutBlock %v with %d bytes", id, len(buf))
|
||||
err = file.PutBlock(id, buf, &storage.PutBlockOptions{ContentMD5: id})
|
||||
|
||||
reader := bytes.NewReader(buf)
|
||||
debug.Log("StageBlock %v with %d bytes", id, len(buf))
|
||||
_, err = blockBlobClient.StageBlock(ctx, id, streaming.NopCloser(reader), &blockblob.StageBlockOptions{
|
||||
TransactionalContentMD5: h[:],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "PutBlock")
|
||||
return errors.Wrap(err, "StageBlock")
|
||||
}
|
||||
|
||||
blocks = append(blocks, storage.Block{
|
||||
ID: id,
|
||||
Status: "Uncommitted",
|
||||
})
|
||||
blocks = append(blocks, id)
|
||||
}
|
||||
|
||||
// sanity check
|
||||
@@ -258,10 +284,10 @@ func (be *Backend) saveLarge(ctx context.Context, objName string, rd restic.Rewi
|
||||
return errors.Errorf("wrote %d bytes instead of the expected %d bytes", uploadedBytes, rd.Length())
|
||||
}
|
||||
|
||||
_, err := blockBlobClient.CommitBlockList(ctx, blocks, &blockblob.CommitBlockListOptions{})
|
||||
|
||||
debug.Log("uploaded %d parts: %v", len(blocks), blocks)
|
||||
err = file.PutBlockList(blocks, nil)
|
||||
debug.Log("PutBlockList returned %v", err)
|
||||
return errors.Wrap(err, "PutBlockList")
|
||||
return errors.Wrap(err, "CommitBlockList")
|
||||
}
|
||||
|
||||
// Load runs fn with a reader that yields the contents of the file at h at the
|
||||
@@ -285,26 +311,22 @@ func (be *Backend) openReader(ctx context.Context, h restic.Handle, length int,
|
||||
}
|
||||
|
||||
objName := be.Filename(h)
|
||||
blob := be.container.GetBlobReference(objName)
|
||||
|
||||
start := uint64(offset)
|
||||
var end uint64
|
||||
|
||||
if length > 0 {
|
||||
end = uint64(offset + int64(length) - 1)
|
||||
} else {
|
||||
end = 0
|
||||
}
|
||||
blockBlobClient := be.container.NewBlobClient(objName)
|
||||
|
||||
be.sem.GetToken()
|
||||
resp, err := blockBlobClient.DownloadStream(ctx, &blob.DownloadStreamOptions{
|
||||
Range: azblob.HTTPRange{
|
||||
Offset: offset,
|
||||
Count: int64(length),
|
||||
},
|
||||
})
|
||||
|
||||
rd, err := blob.GetRange(&storage.GetBlobRangeOptions{Range: &storage.BlobRange{Start: start, End: end}})
|
||||
if err != nil {
|
||||
be.sem.ReleaseToken()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return be.sem.ReleaseTokenOnClose(rd, nil), err
|
||||
return be.sem.ReleaseTokenOnClose(resp.Body, nil), err
|
||||
}
|
||||
|
||||
// Stat returns information about a blob.
|
||||
@@ -312,10 +334,10 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo,
|
||||
debug.Log("%v", h)
|
||||
|
||||
objName := be.Filename(h)
|
||||
blob := be.container.GetBlobReference(objName)
|
||||
blobClient := be.container.NewBlobClient(objName)
|
||||
|
||||
be.sem.GetToken()
|
||||
err := blob.GetProperties(nil)
|
||||
props, err := blobClient.GetProperties(ctx, nil)
|
||||
be.sem.ReleaseToken()
|
||||
|
||||
if err != nil {
|
||||
@@ -324,7 +346,7 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo,
|
||||
}
|
||||
|
||||
fi := restic.FileInfo{
|
||||
Size: int64(blob.Properties.ContentLength),
|
||||
Size: *props.ContentLength,
|
||||
Name: h.Name,
|
||||
}
|
||||
return fi, nil
|
||||
@@ -333,12 +355,18 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo,
|
||||
// Remove removes the blob with the given name and type.
|
||||
func (be *Backend) Remove(ctx context.Context, h restic.Handle) error {
|
||||
objName := be.Filename(h)
|
||||
blob := be.container.NewBlobClient(objName)
|
||||
|
||||
be.sem.GetToken()
|
||||
_, err := be.container.GetBlobReference(objName).DeleteIfExists(nil)
|
||||
_, err := blob.Delete(ctx, &azblob.DeleteBlobOptions{})
|
||||
be.sem.ReleaseToken()
|
||||
|
||||
debug.Log("Remove(%v) at %v -> err %v", h, objName, err)
|
||||
|
||||
if bloberror.HasCode(err, bloberror.BlobNotFound) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Wrap(err, "client.RemoveObject")
|
||||
}
|
||||
|
||||
@@ -354,31 +382,34 @@ func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.F
|
||||
prefix += "/"
|
||||
}
|
||||
|
||||
params := storage.ListBlobsParameters{
|
||||
MaxResults: uint(be.listMaxItems),
|
||||
Prefix: prefix,
|
||||
}
|
||||
max := int32(be.listMaxItems)
|
||||
|
||||
for {
|
||||
opts := &azContainer.ListBlobsFlatOptions{
|
||||
MaxResults: &max,
|
||||
Prefix: &prefix,
|
||||
}
|
||||
lister := be.container.NewListBlobsFlatPager(opts)
|
||||
|
||||
for lister.More() {
|
||||
be.sem.GetToken()
|
||||
obj, err := be.container.ListBlobs(params)
|
||||
resp, err := lister.NextPage(ctx)
|
||||
be.sem.ReleaseToken()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
debug.Log("got %v objects", len(obj.Blobs))
|
||||
debug.Log("got %v objects", len(resp.Segment.BlobItems))
|
||||
|
||||
for _, item := range obj.Blobs {
|
||||
m := strings.TrimPrefix(item.Name, prefix)
|
||||
for _, item := range resp.Segment.BlobItems {
|
||||
m := strings.TrimPrefix(*item.Name, prefix)
|
||||
if m == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fi := restic.FileInfo{
|
||||
Name: path.Base(m),
|
||||
Size: item.Properties.ContentLength,
|
||||
Size: *item.Properties.ContentLength,
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
@@ -395,11 +426,6 @@ func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.F
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if obj.NextMarker == "" {
|
||||
break
|
||||
}
|
||||
params.Marker = obj.NextMarker
|
||||
}
|
||||
|
||||
return ctx.Err()
|
||||
|
@@ -46,7 +46,8 @@ func newAzureTestSuite(t testing.TB) *test.Suite {
|
||||
Create: func(config interface{}) (restic.Backend, error) {
|
||||
cfg := config.(azure.Config)
|
||||
|
||||
be, err := azure.Create(cfg, tr)
|
||||
ctx := context.TODO()
|
||||
be, err := azure.Create(ctx, cfg, tr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,15 +67,15 @@ func newAzureTestSuite(t testing.TB) *test.Suite {
|
||||
// OpenFn is a function that opens a previously created temporary repository.
|
||||
Open: func(config interface{}) (restic.Backend, error) {
|
||||
cfg := config.(azure.Config)
|
||||
|
||||
return azure.Open(cfg, tr)
|
||||
ctx := context.TODO()
|
||||
return azure.Open(ctx, cfg, tr)
|
||||
},
|
||||
|
||||
// CleanupFn removes data created during the tests.
|
||||
Cleanup: func(config interface{}) error {
|
||||
cfg := config.(azure.Config)
|
||||
|
||||
be, err := azure.Open(cfg, tr)
|
||||
ctx := context.TODO()
|
||||
be, err := azure.Open(ctx, cfg, tr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func TestUploadLargeFile(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
be, err := azure.Create(cfg, tr)
|
||||
be, err := azure.Create(ctx, cfg, tr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user