mirror of
https://github.com/restic/restic.git
synced 2025-12-12 08:22:08 +00:00
vss: Add "provider" option
This commit is contained in:
@@ -367,7 +367,7 @@ func (vss *IVssBackupComponents) convertToVSSAsync(
|
||||
}
|
||||
|
||||
// IsVolumeSupported calls the equivalent VSS api.
|
||||
func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, error) {
|
||||
func (vss *IVssBackupComponents) IsVolumeSupported(providerID *ole.GUID, volumeName string) (bool, error) {
|
||||
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -377,7 +377,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
|
||||
var result uintptr
|
||||
|
||||
if runtime.GOARCH == "386" {
|
||||
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL))
|
||||
id := (*[4]uintptr)(unsafe.Pointer(providerID))
|
||||
|
||||
result, _, _ = syscall.Syscall9(vss.getVTable().isVolumeSupported, 7,
|
||||
uintptr(unsafe.Pointer(vss)), id[0], id[1], id[2], id[3],
|
||||
@@ -385,7 +385,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
|
||||
0)
|
||||
} else {
|
||||
result, _, _ = syscall.Syscall6(vss.getVTable().isVolumeSupported, 4,
|
||||
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(ole.IID_NULL)),
|
||||
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(providerID)),
|
||||
uintptr(unsafe.Pointer(volumeNamePointer)), uintptr(unsafe.Pointer(&isSupportedRaw)), 0,
|
||||
0)
|
||||
}
|
||||
@@ -411,7 +411,7 @@ func (vss *IVssBackupComponents) StartSnapshotSet() (ole.GUID, error) {
|
||||
}
|
||||
|
||||
// AddToSnapshotSet calls the equivalent VSS api.
|
||||
func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, idSnapshot *ole.GUID) error {
|
||||
func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, providerID *ole.GUID, idSnapshot *ole.GUID) error {
|
||||
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -420,15 +420,15 @@ func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, idSnapshot
|
||||
var result uintptr
|
||||
|
||||
if runtime.GOARCH == "386" {
|
||||
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL))
|
||||
id := (*[4]uintptr)(unsafe.Pointer(providerID))
|
||||
|
||||
result, _, _ = syscall.Syscall9(vss.getVTable().addToSnapshotSet, 7,
|
||||
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), id[0], id[1],
|
||||
id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
|
||||
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
|
||||
id[0], id[1], id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
|
||||
} else {
|
||||
result, _, _ = syscall.Syscall6(vss.getVTable().addToSnapshotSet, 4,
|
||||
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
|
||||
uintptr(unsafe.Pointer(ole.IID_NULL)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
|
||||
uintptr(unsafe.Pointer(providerID)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
|
||||
}
|
||||
|
||||
return newVssErrorIfResultNotOK("AddToSnapshotSet() failed", HRESULT(result))
|
||||
@@ -535,6 +535,13 @@ func vssFreeSnapshotProperties(properties *VssSnapshotProperties) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func vssFreeProviderProperties(p *VssProviderProperties) {
|
||||
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerName)))
|
||||
p.providerName = nil
|
||||
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerVersion)))
|
||||
p.providerName = nil
|
||||
}
|
||||
|
||||
// BackupComplete calls the equivalent VSS api.
|
||||
func (vss *IVssBackupComponents) BackupComplete() (*IVSSAsync, error) {
|
||||
var oleIUnknown *ole.IUnknown
|
||||
@@ -563,6 +570,17 @@ type VssSnapshotProperties struct {
|
||||
status uint
|
||||
}
|
||||
|
||||
// VssProviderProperties defines the properties of a VSS provider as part of the VSS api.
|
||||
// nolint:structcheck
|
||||
type VssProviderProperties struct {
|
||||
providerID ole.GUID
|
||||
providerName *uint16
|
||||
providerType uint32
|
||||
providerVersion *uint16
|
||||
providerVersionID ole.GUID
|
||||
classID ole.GUID
|
||||
}
|
||||
|
||||
// GetSnapshotDeviceObject returns root path to access the snapshot files
|
||||
// and folders.
|
||||
func (p *VssSnapshotProperties) GetSnapshotDeviceObject() string {
|
||||
@@ -660,6 +678,75 @@ func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(timeout time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UIID_IVSS_ADMIN defines the GUID of IVSSAdmin.
|
||||
var (
|
||||
UIID_IVSS_ADMIN = ole.NewGUID("{77ED5996-2F63-11d3-8A39-00C04F72D8E3}")
|
||||
CLSID_VSS_COORDINATOR = ole.NewGUID("{E579AB5F-1CC4-44b4-BED9-DE0991FF0623}")
|
||||
)
|
||||
|
||||
// IVSSAdmin VSS api interface.
|
||||
type IVSSAdmin struct {
|
||||
ole.IUnknown
|
||||
}
|
||||
|
||||
// IVSSAdminVTable is the vtable for IVSSAdmin.
|
||||
// nolint:structcheck
|
||||
type IVSSAdminVTable struct {
|
||||
ole.IUnknownVtbl
|
||||
registerProvider uintptr
|
||||
unregisterProvider uintptr
|
||||
queryProviders uintptr
|
||||
abortAllSnapshotsInProgress uintptr
|
||||
}
|
||||
|
||||
// getVTable returns the vtable for IVSSAdmin.
|
||||
func (vssAdmin *IVSSAdmin) getVTable() *IVSSAdminVTable {
|
||||
return (*IVSSAdminVTable)(unsafe.Pointer(vssAdmin.RawVTable))
|
||||
}
|
||||
|
||||
// QueryProviders calls the equivalent VSS api.
|
||||
func (vssAdmin *IVSSAdmin) QueryProviders() (*IVssEnumObject, error) {
|
||||
var enum *IVssEnumObject
|
||||
|
||||
result, _, _ := syscall.Syscall(vssAdmin.getVTable().queryProviders, 2,
|
||||
uintptr(unsafe.Pointer(vssAdmin)), uintptr(unsafe.Pointer(&enum)), 0)
|
||||
|
||||
return enum, newVssErrorIfResultNotOK("QueryProviders() failed", HRESULT(result))
|
||||
}
|
||||
|
||||
// IVssEnumObject VSS api interface.
|
||||
type IVssEnumObject struct {
|
||||
ole.IUnknown
|
||||
}
|
||||
|
||||
// IVssEnumObjectVTable is the vtable for IVssEnumObject.
|
||||
// nolint:structcheck
|
||||
type IVssEnumObjectVTable struct {
|
||||
ole.IUnknownVtbl
|
||||
next uintptr
|
||||
skip uintptr
|
||||
reset uintptr
|
||||
clone uintptr
|
||||
}
|
||||
|
||||
// getVTable returns the vtable for IVssEnumObject.
|
||||
func (vssEnum *IVssEnumObject) getVTable() *IVssEnumObjectVTable {
|
||||
return (*IVssEnumObjectVTable)(unsafe.Pointer(vssEnum.RawVTable))
|
||||
}
|
||||
|
||||
// Next calls the equivalent VSS api.
|
||||
func (vssEnum *IVssEnumObject) Next(count uint, props unsafe.Pointer) (uint, error) {
|
||||
var fetched uint32
|
||||
result, _, _ := syscall.Syscall6(vssEnum.getVTable().next, 4,
|
||||
uintptr(unsafe.Pointer(vssEnum)), uintptr(count), uintptr(props),
|
||||
uintptr(unsafe.Pointer(&fetched)), 0, 0)
|
||||
if result == 1 {
|
||||
return uint(fetched), nil
|
||||
}
|
||||
|
||||
return uint(fetched), newVssErrorIfResultNotOK("Next() failed", HRESULT(result))
|
||||
}
|
||||
|
||||
// MountPoint wraps all information of a snapshot of a mountpoint on a volume.
|
||||
type MountPoint struct {
|
||||
isSnapshotted bool
|
||||
@@ -766,7 +853,7 @@ func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) {
|
||||
|
||||
// NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't
|
||||
// finish within the timeout an error is returned.
|
||||
func NewVssSnapshot(
|
||||
func NewVssSnapshot(provider string,
|
||||
volume string, timeout time.Duration, filter VolumeFilter, msgError ErrorHandler) (VssSnapshot, error) {
|
||||
is64Bit, err := isRunningOn64BitWindows()
|
||||
if err != nil {
|
||||
@@ -814,6 +901,12 @@ func NewVssSnapshot(
|
||||
|
||||
iVssBackupComponents := (*IVssBackupComponents)(unsafe.Pointer(comInterface))
|
||||
|
||||
providerID, err := getProviderID(provider)
|
||||
if err != nil {
|
||||
iVssBackupComponents.Release()
|
||||
return VssSnapshot{}, err
|
||||
}
|
||||
|
||||
if err := iVssBackupComponents.InitializeForBackup(); err != nil {
|
||||
iVssBackupComponents.Release()
|
||||
return VssSnapshot{}, err
|
||||
@@ -838,7 +931,7 @@ func NewVssSnapshot(
|
||||
return VssSnapshot{}, err
|
||||
}
|
||||
|
||||
if isSupported, err := iVssBackupComponents.IsVolumeSupported(volume); err != nil {
|
||||
if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, volume); err != nil {
|
||||
iVssBackupComponents.Release()
|
||||
return VssSnapshot{}, err
|
||||
} else if !isSupported {
|
||||
@@ -853,7 +946,7 @@ func NewVssSnapshot(
|
||||
return VssSnapshot{}, err
|
||||
}
|
||||
|
||||
if err := iVssBackupComponents.AddToSnapshotSet(volume, &snapshotSetID); err != nil {
|
||||
if err := iVssBackupComponents.AddToSnapshotSet(volume, providerID, &snapshotSetID); err != nil {
|
||||
iVssBackupComponents.Release()
|
||||
return VssSnapshot{}, err
|
||||
}
|
||||
@@ -877,14 +970,14 @@ func NewVssSnapshot(
|
||||
|
||||
if !filter(mountPoint) {
|
||||
continue
|
||||
} else if isSupported, err := iVssBackupComponents.IsVolumeSupported(mountPoint); err != nil {
|
||||
} else if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, mountPoint); err != nil {
|
||||
continue
|
||||
} else if !isSupported {
|
||||
continue
|
||||
}
|
||||
|
||||
var mountPointSnapshotSetID ole.GUID
|
||||
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, &mountPointSnapshotSetID)
|
||||
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, providerID, &mountPointSnapshotSetID)
|
||||
if err != nil {
|
||||
iVssBackupComponents.Release()
|
||||
|
||||
@@ -988,6 +1081,55 @@ func (p *VssSnapshot) Delete() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getProviderID(provider string) (*ole.GUID, error) {
|
||||
comInterface, err := ole.CreateInstance(CLSID_VSS_COORDINATOR, UIID_IVSS_ADMIN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer comInterface.Release()
|
||||
|
||||
vssAdmin := (*IVSSAdmin)(unsafe.Pointer(comInterface))
|
||||
|
||||
providerLower := strings.ToLower(provider)
|
||||
switch providerLower {
|
||||
case "":
|
||||
return ole.IID_NULL, nil
|
||||
case "ms":
|
||||
return ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}"), nil
|
||||
}
|
||||
|
||||
enum, err := vssAdmin.QueryProviders()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer enum.Release()
|
||||
|
||||
id := ole.NewGUID(provider)
|
||||
|
||||
var props struct {
|
||||
objectType uint32
|
||||
provider VssProviderProperties
|
||||
}
|
||||
for {
|
||||
count, err := enum.Next(1, unsafe.Pointer(&props))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if count < 1 {
|
||||
return nil, errors.Errorf(`invalid VSS provider "%s"`, provider)
|
||||
}
|
||||
|
||||
name := ole.UTF16PtrToString(props.provider.providerName)
|
||||
vssFreeProviderProperties(&props.provider)
|
||||
|
||||
if id != nil && *id == props.provider.providerID ||
|
||||
id == nil && providerLower == strings.ToLower(name) {
|
||||
return &props.provider.providerID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// asyncCallFunc is the callback type for callAsyncFunctionAndWait.
|
||||
type asyncCallFunc func() (*IVSSAsync, error)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user