Merge branch 'master' into windows-securitydesc

This commit is contained in:
Aneesh N
2024-04-29 14:40:34 -06:00
committed by GitHub
162 changed files with 4553 additions and 2518 deletions

View File

@@ -3,41 +3,108 @@ package fs
import (
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/options"
)
// ErrorHandler is used to report errors via callback
type ErrorHandler func(item string, err error) error
// VSSConfig holds extended options of windows volume shadow copy service.
type VSSConfig struct {
ExcludeAllMountPoints bool `option:"exclude-all-mount-points" help:"exclude mountpoints from snapshotting on all volumes"`
ExcludeVolumes string `option:"exclude-volumes" help:"semicolon separated list of volumes to exclude from snapshotting (ex. 'c:\\;e:\\mnt;\\\\?\\Volume{...}')"`
Timeout time.Duration `option:"timeout" help:"time that the VSS can spend creating snapshot before timing out"`
Provider string `option:"provider" help:"VSS provider identifier which will be used for snapshotting"`
}
func init() {
if runtime.GOOS == "windows" {
options.Register("vss", VSSConfig{})
}
}
// NewVSSConfig returns a new VSSConfig with the default values filled in.
func NewVSSConfig() VSSConfig {
return VSSConfig{
Timeout: time.Second * 120,
}
}
// ParseVSSConfig parses a VSS extended options to VSSConfig struct.
func ParseVSSConfig(o options.Options) (VSSConfig, error) {
cfg := NewVSSConfig()
o = o.Extract("vss")
if err := o.Apply("vss", &cfg); err != nil {
return VSSConfig{}, err
}
return cfg, nil
}
// ErrorHandler is used to report errors via callback.
type ErrorHandler func(item string, err error)
// MessageHandler is used to report errors/messages via callbacks.
type MessageHandler func(msg string, args ...interface{})
// VolumeFilter is used to filter volumes by it's mount point or GUID path.
type VolumeFilter func(volume string) bool
// LocalVss is a wrapper around the local file system which uses windows volume
// shadow copy service (VSS) in a transparent way.
type LocalVss struct {
FS
snapshots map[string]VssSnapshot
failedSnapshots map[string]struct{}
mutex sync.RWMutex
msgError ErrorHandler
msgMessage MessageHandler
snapshots map[string]VssSnapshot
failedSnapshots map[string]struct{}
mutex sync.RWMutex
msgError ErrorHandler
msgMessage MessageHandler
excludeAllMountPoints bool
excludeVolumes map[string]struct{}
timeout time.Duration
provider string
}
// statically ensure that LocalVss implements FS.
var _ FS = &LocalVss{}
// parseMountPoints try to convert semicolon separated list of mount points
// to map of lowercased volume GUID pathes. Mountpoints already in volume
// GUID path format will be validated and normalized.
func parseMountPoints(list string, msgError ErrorHandler) (volumes map[string]struct{}) {
if list == "" {
return
}
for _, s := range strings.Split(list, ";") {
if v, err := GetVolumeNameForVolumeMountPoint(s); err != nil {
msgError(s, errors.Errorf("failed to parse vss.exclude-volumes [%s]: %s", s, err))
} else {
if volumes == nil {
volumes = make(map[string]struct{})
}
volumes[strings.ToLower(v)] = struct{}{}
}
}
return
}
// NewLocalVss creates a new wrapper around the windows filesystem using volume
// shadow copy service to access locked files.
func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler) *LocalVss {
func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler, cfg VSSConfig) *LocalVss {
return &LocalVss{
FS: Local{},
snapshots: make(map[string]VssSnapshot),
failedSnapshots: make(map[string]struct{}),
msgError: msgError,
msgMessage: msgMessage,
FS: Local{},
snapshots: make(map[string]VssSnapshot),
failedSnapshots: make(map[string]struct{}),
msgError: msgError,
msgMessage: msgMessage,
excludeAllMountPoints: cfg.ExcludeAllMountPoints,
excludeVolumes: parseMountPoints(cfg.ExcludeVolumes, msgError),
timeout: cfg.Timeout,
provider: cfg.Provider,
}
}
@@ -50,7 +117,7 @@ func (fs *LocalVss) DeleteSnapshots() {
for volumeName, snapshot := range fs.snapshots {
if err := snapshot.Delete(); err != nil {
_ = fs.msgError(volumeName, errors.Errorf("failed to delete VSS snapshot: %s", err))
fs.msgError(volumeName, errors.Errorf("failed to delete VSS snapshot: %s", err))
activeSnapshots[volumeName] = snapshot
}
}
@@ -78,12 +145,27 @@ func (fs *LocalVss) Lstat(name string) (os.FileInfo, error) {
return os.Lstat(fs.snapshotPath(name))
}
// isMountPointIncluded is true if given mountpoint included by user.
func (fs *LocalVss) isMountPointIncluded(mountPoint string) bool {
if fs.excludeVolumes == nil {
return true
}
volume, err := GetVolumeNameForVolumeMountPoint(mountPoint)
if err != nil {
fs.msgError(mountPoint, errors.Errorf("failed to get volume from mount point [%s]: %s", mountPoint, err))
return true
}
_, ok := fs.excludeVolumes[strings.ToLower(volume)]
return !ok
}
// snapshotPath returns the path inside a VSS snapshots if it already exists.
// If the path is not yet available as a snapshot, a snapshot is created.
// If creation of a snapshot fails the file's original path is returned as
// a fallback.
func (fs *LocalVss) snapshotPath(path string) string {
fixPath := fixpath(path)
if strings.HasPrefix(fixPath, `\\?\UNC\`) {
@@ -114,23 +196,36 @@ func (fs *LocalVss) snapshotPath(path string) string {
if !snapshotExists && !snapshotFailed {
vssVolume := volumeNameLower + string(filepath.Separator)
fs.msgMessage("creating VSS snapshot for [%s]\n", vssVolume)
if snapshot, err := NewVssSnapshot(vssVolume, 120, fs.msgError); err != nil {
_ = fs.msgError(vssVolume, errors.Errorf("failed to create snapshot for [%s]: %s",
vssVolume, err))
if !fs.isMountPointIncluded(vssVolume) {
fs.msgMessage("snapshots for [%s] excluded by user\n", vssVolume)
fs.failedSnapshots[volumeNameLower] = struct{}{}
} else {
fs.snapshots[volumeNameLower] = snapshot
fs.msgMessage("successfully created snapshot for [%s]\n", vssVolume)
if len(snapshot.mountPointInfo) > 0 {
fs.msgMessage("mountpoints in snapshot volume [%s]:\n", vssVolume)
for mp, mpInfo := range snapshot.mountPointInfo {
info := ""
if !mpInfo.IsSnapshotted() {
info = " (not snapshotted)"
fs.msgMessage("creating VSS snapshot for [%s]\n", vssVolume)
var includeVolume VolumeFilter
if !fs.excludeAllMountPoints {
includeVolume = func(volume string) bool {
return fs.isMountPointIncluded(volume)
}
}
if snapshot, err := NewVssSnapshot(fs.provider, vssVolume, fs.timeout, includeVolume, fs.msgError); err != nil {
fs.msgError(vssVolume, errors.Errorf("failed to create snapshot for [%s]: %s",
vssVolume, err))
fs.failedSnapshots[volumeNameLower] = struct{}{}
} else {
fs.snapshots[volumeNameLower] = snapshot
fs.msgMessage("successfully created snapshot for [%s]\n", vssVolume)
if len(snapshot.mountPointInfo) > 0 {
fs.msgMessage("mountpoints in snapshot volume [%s]:\n", vssVolume)
for mp, mpInfo := range snapshot.mountPointInfo {
info := ""
if !mpInfo.IsSnapshotted() {
info = " (not snapshotted)"
}
fs.msgMessage(" - %s%s\n", mp, info)
}
fs.msgMessage(" - %s%s\n", mp, info)
}
}
}
@@ -173,9 +268,8 @@ func (fs *LocalVss) snapshotPath(path string) string {
snapshotPath = fs.Join(snapshot.GetSnapshotDeviceObject(),
strings.TrimPrefix(fixPath, volumeName))
if snapshotPath == snapshot.GetSnapshotDeviceObject() {
snapshotPath = snapshotPath + string(filepath.Separator)
snapshotPath += string(filepath.Separator)
}
} else {
// no snapshot is available for the requested path:
// -> try to backup without a snapshot

View File

@@ -0,0 +1,285 @@
// +build windows
package fs
import (
"fmt"
"regexp"
"strings"
"testing"
"time"
ole "github.com/go-ole/go-ole"
"github.com/restic/restic/internal/options"
)
func matchStrings(ptrs []string, strs []string) bool {
if len(ptrs) != len(strs) {
return false
}
for i, p := range ptrs {
if p == "" {
return false
}
matched, err := regexp.MatchString(p, strs[i])
if err != nil {
panic(err)
}
if !matched {
return false
}
}
return true
}
func matchMap(strs []string, m map[string]struct{}) bool {
if len(strs) != len(m) {
return false
}
for _, s := range strs {
if _, ok := m[s]; !ok {
return false
}
}
return true
}
func TestVSSConfig(t *testing.T) {
type config struct {
excludeAllMountPoints bool
timeout time.Duration
provider string
}
setTests := []struct {
input options.Options
output config
}{
{
options.Options{
"vss.timeout": "6h38m42s",
"vss.provider": "Ms",
},
config{
timeout: 23922000000000,
provider: "Ms",
},
},
{
options.Options{
"vss.exclude-all-mount-points": "t",
"vss.provider": "{b5946137-7b9f-4925-af80-51abd60b20d5}",
},
config{
excludeAllMountPoints: true,
timeout: 120000000000,
provider: "{b5946137-7b9f-4925-af80-51abd60b20d5}",
},
},
{
options.Options{
"vss.exclude-all-mount-points": "0",
"vss.exclude-volumes": "",
"vss.timeout": "120s",
"vss.provider": "Microsoft Software Shadow Copy provider 1.0",
},
config{
timeout: 120000000000,
provider: "Microsoft Software Shadow Copy provider 1.0",
},
},
}
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
cfg, err := ParseVSSConfig(test.input)
if err != nil {
t.Fatal(err)
}
errorHandler := func(item string, err error) {
t.Fatalf("unexpected error (%v)", err)
}
messageHandler := func(msg string, args ...interface{}) {
t.Fatalf("unexpected message (%s)", fmt.Sprintf(msg, args))
}
dst := NewLocalVss(errorHandler, messageHandler, cfg)
if dst.excludeAllMountPoints != test.output.excludeAllMountPoints ||
dst.excludeVolumes != nil || dst.timeout != test.output.timeout ||
dst.provider != test.output.provider {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v", test.output, dst)
}
})
}
}
func TestParseMountPoints(t *testing.T) {
volumeMatch := regexp.MustCompile(`^\\\\\?\\Volume\{[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}\}\\$`)
// It's not a good idea to test functions based on GetVolumeNameForVolumeMountPoint by calling
// GetVolumeNameForVolumeMountPoint itself, but we have restricted test environment:
// cannot manage volumes and can only be sure that the mount point C:\ exists
sysVolume, err := GetVolumeNameForVolumeMountPoint("C:")
if err != nil {
t.Fatal(err)
}
// We don't know a valid volume GUID path for c:\, but we'll at least check its format
if !volumeMatch.MatchString(sysVolume) {
t.Fatalf("invalid volume GUID path: %s", sysVolume)
}
// Changing the case and removing trailing backslash allows tests
// the equality of different ways of writing a volume name
sysVolumeMutated := strings.ToUpper(sysVolume[:len(sysVolume)-1])
sysVolumeMatch := strings.ToLower(sysVolume)
type check struct {
volume string
result bool
}
setTests := []struct {
input options.Options
output []string
checks []check
errors []string
}{
{
options.Options{
"vss.exclude-volumes": `c:;c:\;` + sysVolume + `;` + sysVolumeMutated,
},
[]string{
sysVolumeMatch,
},
[]check{
{`c:\`, false},
{`c:`, false},
{sysVolume, false},
{sysVolumeMutated, false},
},
[]string{},
},
{
options.Options{
"vss.exclude-volumes": `z:\nonexistent;c:;c:\windows\;\\?\Volume{39b9cac2-bcdb-4d51-97c8-0d0677d607fb}\`,
},
[]string{
sysVolumeMatch,
},
[]check{
{`c:\windows\`, true},
{`\\?\Volume{39b9cac2-bcdb-4d51-97c8-0d0677d607fb}\`, true},
{`c:`, false},
{``, true},
},
[]string{
`failed to parse vss\.exclude-volumes \[z:\\nonexistent\]:.*`,
`failed to parse vss\.exclude-volumes \[c:\\windows\\\]:.*`,
`failed to parse vss\.exclude-volumes \[\\\\\?\\Volume\{39b9cac2-bcdb-4d51-97c8-0d0677d607fb\}\\\]:.*`,
`failed to get volume from mount point \[c:\\windows\\\]:.*`,
`failed to get volume from mount point \[\\\\\?\\Volume\{39b9cac2-bcdb-4d51-97c8-0d0677d607fb\}\\\]:.*`,
`failed to get volume from mount point \[\]:.*`,
},
},
}
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
cfg, err := ParseVSSConfig(test.input)
if err != nil {
t.Fatal(err)
}
var log []string
errorHandler := func(item string, err error) {
log = append(log, strings.TrimSpace(err.Error()))
}
messageHandler := func(msg string, args ...interface{}) {
t.Fatalf("unexpected message (%s)", fmt.Sprintf(msg, args))
}
dst := NewLocalVss(errorHandler, messageHandler, cfg)
if !matchMap(test.output, dst.excludeVolumes) {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v",
test.output, dst.excludeVolumes)
}
for _, c := range test.checks {
if dst.isMountPointIncluded(c.volume) != c.result {
t.Fatalf(`wrong check: isMountPointIncluded("%s") != %v`, c.volume, c.result)
}
}
if !matchStrings(test.errors, log) {
t.Fatalf("wrong log, want:\n %#v\ngot:\n %#v", test.errors, log)
}
})
}
}
func TestParseProvider(t *testing.T) {
msProvider := ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}")
setTests := []struct {
provider string
id *ole.GUID
result string
}{
{
"",
ole.IID_NULL,
"",
},
{
"mS",
msProvider,
"",
},
{
"{B5946137-7b9f-4925-Af80-51abD60b20d5}",
msProvider,
"",
},
{
"Microsoft Software Shadow Copy provider 1.0",
msProvider,
"",
},
{
"{04560982-3d7d-4bbc-84f7-0712f833a28f}",
nil,
`invalid VSS provider "{04560982-3d7d-4bbc-84f7-0712f833a28f}"`,
},
{
"non-existent provider",
nil,
`invalid VSS provider "non-existent provider"`,
},
}
_ = ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
id, err := getProviderID(test.provider)
if err != nil && id != nil {
t.Fatalf("err!=nil but id=%v", id)
}
if test.result != "" || err != nil {
var result string
if err != nil {
result = err.Error()
}
if test.result != result || test.result == "" {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v", test.result, result)
}
} else if !ole.IsEqualGUID(id, test.id) {
t.Fatalf("wrong id, want:\n %s\ngot:\n %s", test.id.String(), id.String())
}
})
}
}

View File

@@ -5,11 +5,11 @@ import (
"path/filepath"
"testing"
restictest "github.com/restic/restic/internal/test"
rtest "github.com/restic/restic/internal/test"
)
func TestExtendedStat(t *testing.T) {
tempdir := restictest.TempDir(t)
tempdir := rtest.TempDir(t)
filename := filepath.Join(tempdir, "file")
err := os.WriteFile(filename, []byte("foobar"), 0640)
if err != nil {

View File

@@ -4,6 +4,8 @@
package fs
import (
"time"
"github.com/restic/restic/internal/errors"
)
@@ -31,10 +33,16 @@ func HasSufficientPrivilegesForVSS() error {
return errors.New("VSS snapshots are only supported on windows")
}
// GetVolumeNameForVolumeMountPoint add trailing backslash to input parameter
// and calls the equivalent windows api.
func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) {
return mountPoint, nil
}
// NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't
// finish within the timeout an error is returned.
func NewVssSnapshot(
_ string, _ uint, _ ErrorHandler) (VssSnapshot, error) {
func NewVssSnapshot(_ string,
_ string, _ time.Duration, _ VolumeFilter, _ ErrorHandler) (VssSnapshot, error) {
return VssSnapshot{}, errors.New("VSS snapshots are only supported on windows")
}

View File

@@ -5,10 +5,12 @@ package fs
import (
"fmt"
"math"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
ole "github.com/go-ole/go-ole"
@@ -20,8 +22,10 @@ import (
type HRESULT uint
// HRESULT constant values necessary for using VSS api.
//nolint:golint
const (
S_OK HRESULT = 0x00000000
S_FALSE HRESULT = 0x00000001
E_ACCESSDENIED HRESULT = 0x80070005
E_OUTOFMEMORY HRESULT = 0x8007000E
E_INVALIDARG HRESULT = 0x80070057
@@ -190,7 +194,7 @@ func (e *vssError) Error() string {
return fmt.Sprintf("VSS error: %s: %s (%#x)", e.text, e.hresult.Str(), e.hresult)
}
// VssError encapsulates errors returned from calling VSS api.
// vssTextError encapsulates errors returned from calling VSS api.
type vssTextError struct {
text string
}
@@ -255,6 +259,7 @@ type IVssBackupComponents struct {
}
// IVssBackupComponentsVTable is the vtable for IVssBackupComponents.
// nolint:structcheck
type IVssBackupComponentsVTable struct {
ole.IUnknownVtbl
getWriterComponentsCount uintptr
@@ -364,7 +369,7 @@ func (vss *IVssBackupComponents) convertToVSSAsync(
}
// IsVolumeSupported calls the equivalent VSS api.
func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, error) {
func (vss *IVssBackupComponents) IsVolumeSupported(providerID *ole.GUID, volumeName string) (bool, error) {
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
if err != nil {
panic(err)
@@ -374,7 +379,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL))
id := (*[4]uintptr)(unsafe.Pointer(providerID))
result, _, _ = syscall.Syscall9(vss.getVTable().isVolumeSupported, 7,
uintptr(unsafe.Pointer(vss)), id[0], id[1], id[2], id[3],
@@ -382,7 +387,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
0)
} else {
result, _, _ = syscall.Syscall6(vss.getVTable().isVolumeSupported, 4,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(ole.IID_NULL)),
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(providerID)),
uintptr(unsafe.Pointer(volumeNamePointer)), uintptr(unsafe.Pointer(&isSupportedRaw)), 0,
0)
}
@@ -408,24 +413,24 @@ func (vss *IVssBackupComponents) StartSnapshotSet() (ole.GUID, error) {
}
// AddToSnapshotSet calls the equivalent VSS api.
func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, idSnapshot *ole.GUID) error {
func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, providerID *ole.GUID, idSnapshot *ole.GUID) error {
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
if err != nil {
panic(err)
}
var result uintptr = 0
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL))
id := (*[4]uintptr)(unsafe.Pointer(providerID))
result, _, _ = syscall.Syscall9(vss.getVTable().addToSnapshotSet, 7,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), id[0], id[1],
id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
id[0], id[1], id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
} else {
result, _, _ = syscall.Syscall6(vss.getVTable().addToSnapshotSet, 4,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
uintptr(unsafe.Pointer(ole.IID_NULL)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
uintptr(unsafe.Pointer(providerID)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
}
return newVssErrorIfResultNotOK("AddToSnapshotSet() failed", HRESULT(result))
@@ -478,9 +483,9 @@ func (vss *IVssBackupComponents) DoSnapshotSet() (*IVSSAsync, error) {
// DeleteSnapshots calls the equivalent VSS api.
func (vss *IVssBackupComponents) DeleteSnapshots(snapshotID ole.GUID) (int32, ole.GUID, error) {
var deletedSnapshots int32 = 0
var deletedSnapshots int32
var nondeletedSnapshotID ole.GUID
var result uintptr = 0
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(&snapshotID))
@@ -504,7 +509,7 @@ func (vss *IVssBackupComponents) DeleteSnapshots(snapshotID ole.GUID) (int32, ol
// GetSnapshotProperties calls the equivalent VSS api.
func (vss *IVssBackupComponents) GetSnapshotProperties(snapshotID ole.GUID,
properties *VssSnapshotProperties) error {
var result uintptr = 0
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(&snapshotID))
@@ -527,8 +532,8 @@ func vssFreeSnapshotProperties(properties *VssSnapshotProperties) error {
if err != nil {
return err
}
proc.Call(uintptr(unsafe.Pointer(properties)))
// this function always succeeds and returns no value
_, _, _ = proc.Call(uintptr(unsafe.Pointer(properties)))
return nil
}
@@ -543,6 +548,7 @@ func (vss *IVssBackupComponents) BackupComplete() (*IVSSAsync, error) {
}
// VssSnapshotProperties defines the properties of a VSS snapshot as part of the VSS api.
// nolint:structcheck
type VssSnapshotProperties struct {
snapshotID ole.GUID
snapshotSetID ole.GUID
@@ -559,6 +565,24 @@ type VssSnapshotProperties struct {
status uint
}
// VssProviderProperties defines the properties of a VSS provider as part of the VSS api.
// nolint:structcheck
type VssProviderProperties struct {
providerID ole.GUID
providerName *uint16
providerType uint32
providerVersion *uint16
providerVersionID ole.GUID
classID ole.GUID
}
func vssFreeProviderProperties(p *VssProviderProperties) {
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerName)))
p.providerName = nil
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerVersion)))
p.providerVersion = nil
}
// GetSnapshotDeviceObject returns root path to access the snapshot files
// and folders.
func (p *VssSnapshotProperties) GetSnapshotDeviceObject() string {
@@ -617,8 +641,13 @@ func (vssAsync *IVSSAsync) QueryStatus() (HRESULT, uint32) {
// WaitUntilAsyncFinished waits until either the async call is finished or
// the given timeout is reached.
func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(millis uint32) error {
hresult := vssAsync.Wait(millis)
func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(timeout time.Duration) error {
const maxTimeout = math.MaxInt32 * time.Millisecond
if timeout > maxTimeout {
timeout = maxTimeout
}
hresult := vssAsync.Wait(uint32(timeout.Milliseconds()))
err := newVssErrorIfResultNotOK("Wait() failed", hresult)
if err != nil {
vssAsync.Cancel()
@@ -651,6 +680,75 @@ func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(millis uint32) error {
return nil
}
// UIID_IVSS_ADMIN defines the GUID of IVSSAdmin.
var (
UIID_IVSS_ADMIN = ole.NewGUID("{77ED5996-2F63-11d3-8A39-00C04F72D8E3}")
CLSID_VSS_COORDINATOR = ole.NewGUID("{E579AB5F-1CC4-44b4-BED9-DE0991FF0623}")
)
// IVSSAdmin VSS api interface.
type IVSSAdmin struct {
ole.IUnknown
}
// IVSSAdminVTable is the vtable for IVSSAdmin.
// nolint:structcheck
type IVSSAdminVTable struct {
ole.IUnknownVtbl
registerProvider uintptr
unregisterProvider uintptr
queryProviders uintptr
abortAllSnapshotsInProgress uintptr
}
// getVTable returns the vtable for IVSSAdmin.
func (vssAdmin *IVSSAdmin) getVTable() *IVSSAdminVTable {
return (*IVSSAdminVTable)(unsafe.Pointer(vssAdmin.RawVTable))
}
// QueryProviders calls the equivalent VSS api.
func (vssAdmin *IVSSAdmin) QueryProviders() (*IVssEnumObject, error) {
var enum *IVssEnumObject
result, _, _ := syscall.Syscall(vssAdmin.getVTable().queryProviders, 2,
uintptr(unsafe.Pointer(vssAdmin)), uintptr(unsafe.Pointer(&enum)), 0)
return enum, newVssErrorIfResultNotOK("QueryProviders() failed", HRESULT(result))
}
// IVssEnumObject VSS api interface.
type IVssEnumObject struct {
ole.IUnknown
}
// IVssEnumObjectVTable is the vtable for IVssEnumObject.
// nolint:structcheck
type IVssEnumObjectVTable struct {
ole.IUnknownVtbl
next uintptr
skip uintptr
reset uintptr
clone uintptr
}
// getVTable returns the vtable for IVssEnumObject.
func (vssEnum *IVssEnumObject) getVTable() *IVssEnumObjectVTable {
return (*IVssEnumObjectVTable)(unsafe.Pointer(vssEnum.RawVTable))
}
// Next calls the equivalent VSS api.
func (vssEnum *IVssEnumObject) Next(count uint, props unsafe.Pointer) (uint, error) {
var fetched uint32
result, _, _ := syscall.Syscall6(vssEnum.getVTable().next, 4,
uintptr(unsafe.Pointer(vssEnum)), uintptr(count), uintptr(props),
uintptr(unsafe.Pointer(&fetched)), 0, 0)
if HRESULT(result) == S_FALSE {
return uint(fetched), nil
}
return uint(fetched), newVssErrorIfResultNotOK("Next() failed", HRESULT(result))
}
// MountPoint wraps all information of a snapshot of a mountpoint on a volume.
type MountPoint struct {
isSnapshotted bool
@@ -677,7 +775,7 @@ type VssSnapshot struct {
snapshotProperties VssSnapshotProperties
snapshotDeviceObject string
mountPointInfo map[string]MountPoint
timeoutInMillis uint32
timeout time.Duration
}
// GetSnapshotDeviceObject returns root path to access the snapshot files
@@ -694,7 +792,12 @@ func initializeVssCOMInterface() (*ole.IUnknown, error) {
}
// ensure COM is initialized before use
ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
if err = ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
// CoInitializeEx returns S_FALSE if COM is already initialized
if oleErr, ok := err.(*ole.OleError); !ok || HRESULT(oleErr.Code()) != S_FALSE {
return nil, err
}
}
var oleIUnknown *ole.IUnknown
result, _, _ := vssInstance.Call(uintptr(unsafe.Pointer(&oleIUnknown)))
@@ -727,12 +830,34 @@ func HasSufficientPrivilegesForVSS() error {
return err
}
// GetVolumeNameForVolumeMountPoint add trailing backslash to input parameter
// and calls the equivalent windows api.
func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) {
if mountPoint != "" && mountPoint[len(mountPoint)-1] != filepath.Separator {
mountPoint += string(filepath.Separator)
}
mountPointPointer, err := syscall.UTF16PtrFromString(mountPoint)
if err != nil {
return mountPoint, err
}
// A reasonable size for the buffer to accommodate the largest possible
// volume GUID path is 50 characters.
volumeNameBuffer := make([]uint16, 50)
if err := windows.GetVolumeNameForVolumeMountPoint(
mountPointPointer, &volumeNameBuffer[0], 50); err != nil {
return mountPoint, err
}
return syscall.UTF16ToString(volumeNameBuffer), nil
}
// NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't
// finish within the timeout an error is returned.
func NewVssSnapshot(
volume string, timeoutInSeconds uint, msgError ErrorHandler) (VssSnapshot, error) {
func NewVssSnapshot(provider string,
volume string, timeout time.Duration, filter VolumeFilter, msgError ErrorHandler) (VssSnapshot, error) {
is64Bit, err := isRunningOn64BitWindows()
if err != nil {
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"Failed to detect windows architecture: %s", err.Error()))
@@ -744,7 +869,7 @@ func NewVssSnapshot(
runtime.GOARCH))
}
timeoutInMillis := uint32(timeoutInSeconds * 1000)
deadline := time.Now().Add(timeout)
oleIUnknown, err := initializeVssCOMInterface()
if oleIUnknown != nil {
@@ -778,6 +903,12 @@ func NewVssSnapshot(
iVssBackupComponents := (*IVssBackupComponents)(unsafe.Pointer(comInterface))
providerID, err := getProviderID(provider)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
if err := iVssBackupComponents.InitializeForBackup(); err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
@@ -796,13 +927,13 @@ func NewVssSnapshot(
}
err = callAsyncFunctionAndWait(iVssBackupComponents.GatherWriterMetadata,
"GatherWriterMetadata", timeoutInMillis)
"GatherWriterMetadata", deadline)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
if isSupported, err := iVssBackupComponents.IsVolumeSupported(volume); err != nil {
if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, volume); err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
} else if !isSupported {
@@ -817,44 +948,53 @@ func NewVssSnapshot(
return VssSnapshot{}, err
}
if err := iVssBackupComponents.AddToSnapshotSet(volume, &snapshotSetID); err != nil {
if err := iVssBackupComponents.AddToSnapshotSet(volume, providerID, &snapshotSetID); err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
mountPoints, err := enumerateMountedFolders(volume)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"failed to enumerate mount points for volume %s: %s", volume, err))
}
mountPointInfo := make(map[string]MountPoint)
for _, mountPoint := range mountPoints {
// ensure every mountpoint is available even without a valid
// snapshot because we need to consider this when backing up files
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: false}
if isSupported, err := iVssBackupComponents.IsVolumeSupported(mountPoint); err != nil {
continue
} else if !isSupported {
continue
}
var mountPointSnapshotSetID ole.GUID
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, &mountPointSnapshotSetID)
// if filter==nil just don't process mount points for this volume at all
if filter != nil {
mountPoints, err := enumerateMountedFolders(volume)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"failed to enumerate mount points for volume %s: %s", volume, err))
}
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: true,
snapshotSetID: mountPointSnapshotSetID}
for _, mountPoint := range mountPoints {
// ensure every mountpoint is available even without a valid
// snapshot because we need to consider this when backing up files
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: false}
if !filter(mountPoint) {
continue
} else if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, mountPoint); err != nil {
continue
} else if !isSupported {
continue
}
var mountPointSnapshotSetID ole.GUID
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, providerID, &mountPointSnapshotSetID)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
mountPointInfo[mountPoint] = MountPoint{
isSnapshotted: true,
snapshotSetID: mountPointSnapshotSetID,
}
}
}
err = callAsyncFunctionAndWait(iVssBackupComponents.PrepareForBackup, "PrepareForBackup",
timeoutInMillis)
deadline)
if err != nil {
// After calling PrepareForBackup one needs to call AbortBackup() before releasing the VSS
// instance for proper cleanup.
@@ -865,9 +1005,9 @@ func NewVssSnapshot(
}
err = callAsyncFunctionAndWait(iVssBackupComponents.DoSnapshotSet, "DoSnapshotSet",
timeoutInMillis)
deadline)
if err != nil {
iVssBackupComponents.AbortBackup()
_ = iVssBackupComponents.AbortBackup()
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
@@ -875,13 +1015,12 @@ func NewVssSnapshot(
var snapshotProperties VssSnapshotProperties
err = iVssBackupComponents.GetSnapshotProperties(snapshotSetID, &snapshotProperties)
if err != nil {
iVssBackupComponents.AbortBackup()
_ = iVssBackupComponents.AbortBackup()
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
for mountPoint, info := range mountPointInfo {
if !info.isSnapshotted {
continue
}
@@ -900,8 +1039,10 @@ func NewVssSnapshot(
mountPointInfo[mountPoint] = info
}
return VssSnapshot{iVssBackupComponents, snapshotSetID, snapshotProperties,
snapshotProperties.GetSnapshotDeviceObject(), mountPointInfo, timeoutInMillis}, nil
return VssSnapshot{
iVssBackupComponents, snapshotSetID, snapshotProperties,
snapshotProperties.GetSnapshotDeviceObject(), mountPointInfo, time.Until(deadline),
}, nil
}
// Delete deletes the created snapshot.
@@ -922,15 +1063,17 @@ func (p *VssSnapshot) Delete() error {
if p.iVssBackupComponents != nil {
defer p.iVssBackupComponents.Release()
deadline := time.Now().Add(p.timeout)
err = callAsyncFunctionAndWait(p.iVssBackupComponents.BackupComplete, "BackupComplete",
p.timeoutInMillis)
deadline)
if err != nil {
return err
}
if _, _, e := p.iVssBackupComponents.DeleteSnapshots(p.snapshotID); e != nil {
err = newVssTextError(fmt.Sprintf("Failed to delete snapshot: %s", e.Error()))
p.iVssBackupComponents.AbortBackup()
_ = p.iVssBackupComponents.AbortBackup()
if err != nil {
return err
}
@@ -940,12 +1083,61 @@ func (p *VssSnapshot) Delete() error {
return nil
}
func getProviderID(provider string) (*ole.GUID, error) {
providerLower := strings.ToLower(provider)
switch providerLower {
case "":
return ole.IID_NULL, nil
case "ms":
return ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}"), nil
}
comInterface, err := ole.CreateInstance(CLSID_VSS_COORDINATOR, UIID_IVSS_ADMIN)
if err != nil {
return nil, err
}
defer comInterface.Release()
vssAdmin := (*IVSSAdmin)(unsafe.Pointer(comInterface))
enum, err := vssAdmin.QueryProviders()
if err != nil {
return nil, err
}
defer enum.Release()
id := ole.NewGUID(provider)
var props struct {
objectType uint32
provider VssProviderProperties
}
for {
count, err := enum.Next(1, unsafe.Pointer(&props))
if err != nil {
return nil, err
}
if count < 1 {
return nil, errors.Errorf(`invalid VSS provider "%s"`, provider)
}
name := ole.UTF16PtrToString(props.provider.providerName)
vssFreeProviderProperties(&props.provider)
if id != nil && *id == props.provider.providerID ||
id == nil && providerLower == strings.ToLower(name) {
return &props.provider.providerID, nil
}
}
}
// asyncCallFunc is the callback type for callAsyncFunctionAndWait.
type asyncCallFunc func() (*IVSSAsync, error)
// callAsyncFunctionAndWait calls an async functions and waits for it to either
// finish or timeout.
func callAsyncFunctionAndWait(function asyncCallFunc, name string, timeoutInMillis uint32) error {
func callAsyncFunctionAndWait(function asyncCallFunc, name string, deadline time.Time) error {
iVssAsync, err := function()
if err != nil {
return err
@@ -955,7 +1147,12 @@ func callAsyncFunctionAndWait(function asyncCallFunc, name string, timeoutInMill
return newVssTextError(fmt.Sprintf("%s() returned nil", name))
}
err = iVssAsync.WaitUntilAsyncFinished(timeoutInMillis)
timeout := time.Until(deadline)
if timeout <= 0 {
return newVssTextError(fmt.Sprintf("%s() deadline exceeded", name))
}
err = iVssAsync.WaitUntilAsyncFinished(timeout)
iVssAsync.Release()
return err
}
@@ -1036,6 +1233,7 @@ func enumerateMountedFolders(volume string) ([]string, error) {
return mountedFolders, nil
}
// nolint:errcheck
defer windows.FindVolumeMountPointClose(handle)
volumeMountPoint := syscall.UTF16ToString(volumeMountPointBuffer)