This commit is contained in:
DRON-666 2024-04-24 23:58:42 +00:00 committed by GitHub
commit 18cb0fabd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 747 additions and 96 deletions

View File

@ -0,0 +1,22 @@
Enhancement: Add options to configure Windows Shadow Copy Service
Restic always used 120 sec. timeout and unconditionally created VSS snapshots
for all volume mount points on disk. Now this behavior can be fine-tuned by
new options, like exclude user specific volumes and mount points or completely
disable auto snapshotting of volume mount points.
For example:
restic backup --use-fs-snapshot -o vss.timeout=5m -o vss.excludeallmountpoints=true
changes timeout to five minutes and disable snapshotting of mount points on all volumes, and
restic backup --use-fs-snapshot -o vss.excludevolumes="d:\;c:\mnt\;\\?\Volume{e2e0315d-9066-4f97-8343-eb5659b35762}"
excludes drive `D:`, mount point `C:\MNT` and specific volume from VSS snapshotting.
restic backup --use-fs-snapshot -o vss.provider={b5946137-7b9f-4925-af80-51abd60b20d5}
uses 'Microsoft Software Shadow Copy provider 1.0' instead of the default provider.
https://github.com/restic/restic/pull/3067

View File

@ -445,7 +445,16 @@ func findParentSnapshot(ctx context.Context, repo restic.ListerLoaderUnpacked, o
}
func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, term *termstatus.Terminal, args []string) error {
err := opts.Check(gopts, args)
var vsscfg fs.VSSConfig
var err error
if runtime.GOOS == "windows" {
if vsscfg, err = fs.ParseVSSConfig(gopts.extended); err != nil {
return err
}
}
err = opts.Check(gopts, args)
if err != nil {
return err
}
@ -547,8 +556,8 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter
return err
}
errorHandler := func(item string, err error) error {
return progressReporter.Error(item, err)
errorHandler := func(item string, err error) {
_ = progressReporter.Error(item, err)
}
messageHandler := func(msg string, args ...interface{}) {
@ -557,7 +566,7 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter
}
}
localVss := fs.NewLocalVss(errorHandler, messageHandler)
localVss := fs.NewLocalVss(errorHandler, messageHandler, vsscfg)
defer localVss.DeleteSnapshots()
targetFS = localVss
}

View File

@ -56,6 +56,39 @@ snapshot for each volume that contains files to backup. Files are read from the
VSS snapshot instead of the regular filesystem. This allows to backup files that are
exclusively locked by another process during the backup.
You can use additional options to change VSS behaviour:
* ``-o vss.timeout`` specifies timeout for VSS snapshot creation, the default value is 120 seconds
* ``-o vss.excludeallmountpoints`` disable auto snapshotting of all volume mount points
* ``-o vss.excludevolumes`` allows excluding specific volumes or volume mount points from snapshotting
* ``-o vss.provider`` specifies VSS provider used for snapshotting
E.g., 2.5 minutes timeout with mount points snapshotting disabled can be specified as
.. code-block:: console
-o vss.timeout=2m30s -o vss.excludeallmountpoints=true
and excluding drive ``D:\``, mount point ``C:\mnt`` and volume ``\\?\Volume{04ce0545-3391-11e0-ba2f-806e6f6e6963}\`` as
.. code-block:: console
-o vss.excludevolumes="d:;c:\MNT\;\\?\volume{04ce0545-3391-11e0-ba2f-806e6f6e6963}"
VSS provider can be specified by GUID
.. code-block:: console
-o vss.provider={3f900f90-00e9-440e-873a-96ca5eb079e5}
or by name
.. code-block:: console
-o vss.provider="Hyper-V IC Software Shadow Copy Provider"
Also ``MS`` can be used as alias for ``Microsoft Software Shadow Copy provider 1.0``.
By default VSS ignores Outlook OST files. This is not a restriction of restic
but the default Windows VSS configuration. The files not to snapshot are
configured in the Windows registry under the following key:

View File

@ -3,41 +3,108 @@ package fs
import (
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/options"
)
// ErrorHandler is used to report errors via callback
type ErrorHandler func(item string, err error) error
// VSSConfig holds extended options of windows volume shadow copy service.
type VSSConfig struct {
ExcludeAllMountPoints bool `option:"excludeallmountpoints" help:"exclude mountpoints from snapshotting on all volumes"`
ExcludeVolumes string `option:"excludevolumes" help:"semicolon separated list of volumes to exclude from snapshotting (ex. 'c:\\;e:\\mnt;\\\\?\\Volume{...}')"`
Timeout time.Duration `option:"timeout" help:"time that the VSS can spend creating snapshot before timing out"`
Provider string `option:"provider" help:"VSS provider identifier which will be used for snapshotting"`
}
func init() {
if runtime.GOOS == "windows" {
options.Register("vss", VSSConfig{})
}
}
// NewVSSConfig returns a new VSSConfig with the default values filled in.
func NewVSSConfig() VSSConfig {
return VSSConfig{
Timeout: time.Second * 120,
}
}
// ParseVSSConfig parses a VSS extended options to VSSConfig struct.
func ParseVSSConfig(o options.Options) (VSSConfig, error) {
cfg := NewVSSConfig()
o = o.Extract("vss")
if err := o.Apply("vss", &cfg); err != nil {
return VSSConfig{}, err
}
return cfg, nil
}
// ErrorHandler is used to report errors via callback.
type ErrorHandler func(item string, err error)
// MessageHandler is used to report errors/messages via callbacks.
type MessageHandler func(msg string, args ...interface{})
// VolumeFilter is used to filter volumes by it's mount point or GUID path.
type VolumeFilter func(volume string) bool
// LocalVss is a wrapper around the local file system which uses windows volume
// shadow copy service (VSS) in a transparent way.
type LocalVss struct {
FS
snapshots map[string]VssSnapshot
failedSnapshots map[string]struct{}
mutex sync.RWMutex
msgError ErrorHandler
msgMessage MessageHandler
snapshots map[string]VssSnapshot
failedSnapshots map[string]struct{}
mutex sync.RWMutex
msgError ErrorHandler
msgMessage MessageHandler
excludeAllMountPoints bool
excludeVolumes map[string]struct{}
timeout time.Duration
provider string
}
// statically ensure that LocalVss implements FS.
var _ FS = &LocalVss{}
// parseMountPoints try to convert semicolon separated list of mount points
// to map of lowercased volume GUID pathes. Mountpoints already in volume
// GUID path format will be validated and normalized.
func parseMountPoints(list string, msgError ErrorHandler) (volumes map[string]struct{}) {
if list == "" {
return
}
for _, s := range strings.Split(list, ";") {
if v, err := GetVolumeNameForVolumeMountPoint(s); err != nil {
msgError(s, errors.Errorf("failed to parse vss.excludevolumes [%s]: %s", s, err))
} else {
if volumes == nil {
volumes = make(map[string]struct{})
}
volumes[strings.ToLower(v)] = struct{}{}
}
}
return
}
// NewLocalVss creates a new wrapper around the windows filesystem using volume
// shadow copy service to access locked files.
func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler) *LocalVss {
func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler, cfg VSSConfig) *LocalVss {
return &LocalVss{
FS: Local{},
snapshots: make(map[string]VssSnapshot),
failedSnapshots: make(map[string]struct{}),
msgError: msgError,
msgMessage: msgMessage,
FS: Local{},
snapshots: make(map[string]VssSnapshot),
failedSnapshots: make(map[string]struct{}),
msgError: msgError,
msgMessage: msgMessage,
excludeAllMountPoints: cfg.ExcludeAllMountPoints,
excludeVolumes: parseMountPoints(cfg.ExcludeVolumes, msgError),
timeout: cfg.Timeout,
provider: cfg.Provider,
}
}
@ -50,7 +117,7 @@ func (fs *LocalVss) DeleteSnapshots() {
for volumeName, snapshot := range fs.snapshots {
if err := snapshot.Delete(); err != nil {
_ = fs.msgError(volumeName, errors.Errorf("failed to delete VSS snapshot: %s", err))
fs.msgError(volumeName, errors.Errorf("failed to delete VSS snapshot: %s", err))
activeSnapshots[volumeName] = snapshot
}
}
@ -78,12 +145,29 @@ func (fs *LocalVss) Lstat(name string) (os.FileInfo, error) {
return os.Lstat(fs.snapshotPath(name))
}
// isMountPointExcluded is true if given mountpoint excluded by user.
func (fs *LocalVss) isMountPointExcluded(mountPoint string) bool {
if fs.excludeVolumes == nil {
return false
}
volume, err := GetVolumeNameForVolumeMountPoint(mountPoint)
if err != nil {
fs.msgError(mountPoint, errors.Errorf("failed to get volume from mount point [%s]: %s", mountPoint, err))
return false
}
_, ok := fs.excludeVolumes[strings.ToLower(volume)]
return ok
}
// snapshotPath returns the path inside a VSS snapshots if it already exists.
// If the path is not yet available as a snapshot, a snapshot is created.
// If creation of a snapshot fails the file's original path is returned as
// a fallback.
func (fs *LocalVss) snapshotPath(path string) string {
fixPath := fixpath(path)
if strings.HasPrefix(fixPath, `\\?\UNC\`) {
@ -114,23 +198,36 @@ func (fs *LocalVss) snapshotPath(path string) string {
if !snapshotExists && !snapshotFailed {
vssVolume := volumeNameLower + string(filepath.Separator)
fs.msgMessage("creating VSS snapshot for [%s]\n", vssVolume)
if snapshot, err := NewVssSnapshot(vssVolume, 120, fs.msgError); err != nil {
_ = fs.msgError(vssVolume, errors.Errorf("failed to create snapshot for [%s]: %s",
vssVolume, err))
if fs.isMountPointExcluded(vssVolume) {
fs.msgMessage("snapshots for [%s] excluded by user\n", vssVolume)
fs.failedSnapshots[volumeNameLower] = struct{}{}
} else {
fs.snapshots[volumeNameLower] = snapshot
fs.msgMessage("successfully created snapshot for [%s]\n", vssVolume)
if len(snapshot.mountPointInfo) > 0 {
fs.msgMessage("mountpoints in snapshot volume [%s]:\n", vssVolume)
for mp, mpInfo := range snapshot.mountPointInfo {
info := ""
if !mpInfo.IsSnapshotted() {
info = " (not snapshotted)"
fs.msgMessage("creating VSS snapshot for [%s]\n", vssVolume)
var filter VolumeFilter
if !fs.excludeAllMountPoints {
filter = func(volume string) bool {
return !fs.isMountPointExcluded(volume)
}
}
if snapshot, err := NewVssSnapshot(fs.provider, vssVolume, fs.timeout, filter, fs.msgError); err != nil {
fs.msgError(vssVolume, errors.Errorf("failed to create snapshot for [%s]: %s",
vssVolume, err))
fs.failedSnapshots[volumeNameLower] = struct{}{}
} else {
fs.snapshots[volumeNameLower] = snapshot
fs.msgMessage("successfully created snapshot for [%s]\n", vssVolume)
if len(snapshot.mountPointInfo) > 0 {
fs.msgMessage("mountpoints in snapshot volume [%s]:\n", vssVolume)
for mp, mpInfo := range snapshot.mountPointInfo {
info := ""
if !mpInfo.IsSnapshotted() {
info = " (not snapshotted)"
}
fs.msgMessage(" - %s%s\n", mp, info)
}
fs.msgMessage(" - %s%s\n", mp, info)
}
}
}
@ -173,9 +270,8 @@ func (fs *LocalVss) snapshotPath(path string) string {
snapshotPath = fs.Join(snapshot.GetSnapshotDeviceObject(),
strings.TrimPrefix(fixPath, volumeName))
if snapshotPath == snapshot.GetSnapshotDeviceObject() {
snapshotPath = snapshotPath + string(filepath.Separator)
snapshotPath += string(filepath.Separator)
}
} else {
// no snapshot is available for the requested path:
// -> try to backup without a snapshot

View File

@ -0,0 +1,287 @@
// +build windows
package fs
import (
"fmt"
"regexp"
"strings"
"testing"
"time"
ole "github.com/go-ole/go-ole"
"github.com/restic/restic/internal/options"
)
func matchStrings(ptrs []string, strs []string) bool {
if len(ptrs) != len(strs) {
return false
}
for i, p := range ptrs {
if p == "" {
return false
}
matched, err := regexp.MatchString(p, strs[i])
if err != nil {
panic(err)
}
if !matched {
return false
}
}
return true
}
func matchMap(strs []string, m map[string]struct{}) bool {
if len(strs) != len(m) {
return false
}
for _, s := range strs {
if _, ok := m[s]; !ok {
return false
}
}
return true
}
func TestVSSConfig(t *testing.T) {
type config struct {
excludeAllMountPoints bool
timeout time.Duration
provider string
}
setTests := []struct {
input options.Options
output config
}{
{
options.Options{
"vss.timeout": "6h38m42s",
"vss.provider": "Ms",
},
config{
timeout: 23922000000000,
provider: "Ms",
},
},
{
options.Options{
"vss.excludeallmountpoints": "t",
"vss.provider": "{b5946137-7b9f-4925-af80-51abd60b20d5}",
},
config{
excludeAllMountPoints: true,
timeout: 120000000000,
provider: "{b5946137-7b9f-4925-af80-51abd60b20d5}",
},
},
{
options.Options{
"vss.excludeallmountpoints": "0",
"vss.excludevolumes": "",
"vss.timeout": "120s",
"vss.provider": "Microsoft Software Shadow Copy provider 1.0",
},
config{
timeout: 120000000000,
provider: "Microsoft Software Shadow Copy provider 1.0",
},
},
}
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
cfg, err := ParseVSSConfig(test.input)
if err != nil {
t.Fatal(err)
}
errorHandler := func(item string, err error) {
t.Fatalf("unexpected error (%v)", err)
}
messageHandler := func(msg string, args ...interface{}) {
t.Fatalf("unexpected message (%s)", fmt.Sprintf(msg, args))
}
dst := NewLocalVss(errorHandler, messageHandler, cfg)
if dst.excludeAllMountPoints != test.output.excludeAllMountPoints ||
dst.excludeVolumes != nil || dst.timeout != test.output.timeout ||
dst.provider != test.output.provider {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v", test.output, dst)
}
})
}
}
func TestParseMountPoints(t *testing.T) {
volumeMatch := regexp.MustCompile(`^\\\\\?\\Volume\{[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}\}\\$`)
// It's not a good idea to test functions based on GetVolumeNameForVolumeMountPoint by calling
// GetVolumeNameForVolumeMountPoint itself, but we have restricted test environment:
// cannot manage volumes and can only be sure that the mount point C:\ exists
sysVolume, err := GetVolumeNameForVolumeMountPoint("C:")
if err != nil {
t.Fatal(err)
}
// We don't know a valid volume GUID path for C:\, but we'll at least check its format
if !volumeMatch.MatchString(sysVolume) {
t.Fatalf("invalid volume GUID path: %s", sysVolume)
}
sysVolumeMutated := strings.ToUpper(sysVolume[:len(sysVolume)-1])
sysVolumeMatch := strings.ToLower(sysVolume)
type check struct {
volume string
result bool
}
setTests := []struct {
input options.Options
output []string
checks []check
errors []string
}{
{
options.Options{
"vss.excludevolumes": `c:;c:\;` + sysVolume + `;` + sysVolumeMutated,
},
[]string{
sysVolumeMatch,
},
[]check{
{`c:\`, true},
{`c:`, true},
{sysVolume, true},
{sysVolumeMutated, true},
},
[]string{},
},
{
options.Options{
"vss.excludevolumes": `z:\nonexistent;c:;c:\windows\;\\?\Volume{39b9cac2-bcdb-4d51-97c8-0d0677d607fb}\`,
},
[]string{
sysVolumeMatch,
},
[]check{
{`c:\windows\`, false},
{`\\?\Volume{39b9cac2-bcdb-4d51-97c8-0d0677d607fb}\`, false},
{`c:`, true},
{``, false},
},
[]string{
`failed to parse vss\.excludevolumes \[z:\\nonexistent\]:.*`,
`failed to parse vss\.excludevolumes \[c:\\windows\\\]:.*`,
`failed to parse vss\.excludevolumes \[\\\\\?\\Volume\{39b9cac2-bcdb-4d51-97c8-0d0677d607fb\}\\\]:.*`,
`failed to get volume from mount point \[c:\\windows\\\]:.*`,
`failed to get volume from mount point \[\\\\\?\\Volume\{39b9cac2-bcdb-4d51-97c8-0d0677d607fb\}\\\]:.*`,
`failed to get volume from mount point \[\]:.*`,
},
},
}
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
cfg, err := ParseVSSConfig(test.input)
if err != nil {
t.Fatal(err)
}
var log []string
errorHandler := func(item string, err error) {
log = append(log, strings.TrimSpace(err.Error()))
}
messageHandler := func(msg string, args ...interface{}) {
t.Fatalf("unexpected message (%s)", fmt.Sprintf(msg, args))
}
dst := NewLocalVss(errorHandler, messageHandler, cfg)
if !matchMap(test.output, dst.excludeVolumes) {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v",
test.output, dst.excludeVolumes)
}
for _, c := range test.checks {
if dst.isMountPointExcluded(c.volume) != c.result {
t.Fatalf(`wrong check: isMountPointExcluded("%s") != %v`, c.volume, c.result)
}
}
if !matchStrings(test.errors, log) {
t.Fatalf("wrong log, want:\n %#v\ngot:\n %#v", test.errors, log)
}
})
}
}
func TestParseProvider(t *testing.T) {
msProvider := ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}")
setTests := []struct {
provider string
id *ole.GUID
result string
}{
{
"",
ole.IID_NULL,
"",
},
{
"mS",
msProvider,
"",
},
{
"{B5946137-7b9f-4925-Af80-51abD60b20d5}",
msProvider,
"",
},
{
"Microsoft Software Shadow Copy provider 1.0",
msProvider,
"",
},
{
"{04560982-3d7d-4bbc-84f7-0712f833a28f}",
nil,
`invalid VSS provider "{04560982-3d7d-4bbc-84f7-0712f833a28f}"`,
},
{
"non-existent provider",
nil,
`invalid VSS provider "non-existent provider"`,
},
}
_ = ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
id, err := getProviderID(test.provider)
if err != nil && id != nil {
t.Fatalf("err!=nil but id=%v", id)
}
if test.result != "" || err != nil {
var result string
if err != nil {
result = err.Error()
}
matched, err := regexp.MatchString(test.result, result)
if err != nil {
panic(err)
}
if !matched || test.result == "" {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v", test.result, result)
}
} else if !ole.IsEqualGUID(id, test.id) {
t.Fatalf("wrong id, want:\n %s\ngot:\n %s", test.id.String(), id.String())
}
})
}
}

View File

@ -4,6 +4,8 @@
package fs
import (
"time"
"github.com/restic/restic/internal/errors"
)
@ -31,10 +33,16 @@ func HasSufficientPrivilegesForVSS() error {
return errors.New("VSS snapshots are only supported on windows")
}
// GetVolumeNameForVolumeMountPoint clear input parameter
// and calls the equivalent windows api.
func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) {
return mountPoint, nil
}
// NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't
// finish within the timeout an error is returned.
func NewVssSnapshot(
_ string, _ uint, _ ErrorHandler) (VssSnapshot, error) {
func NewVssSnapshot(_ string,
_ string, _ time.Duration, _ VolumeFilter, _ ErrorHandler) (VssSnapshot, error) {
return VssSnapshot{}, errors.New("VSS snapshots are only supported on windows")
}

View File

@ -9,6 +9,7 @@ import (
"runtime"
"strings"
"syscall"
"time"
"unsafe"
ole "github.com/go-ole/go-ole"
@ -20,6 +21,7 @@ import (
type HRESULT uint
// HRESULT constant values necessary for using VSS api.
//nolint:golint
const (
S_OK HRESULT = 0x00000000
E_ACCESSDENIED HRESULT = 0x80070005
@ -255,6 +257,7 @@ type IVssBackupComponents struct {
}
// IVssBackupComponentsVTable is the vtable for IVssBackupComponents.
// nolint:structcheck
type IVssBackupComponentsVTable struct {
ole.IUnknownVtbl
getWriterComponentsCount uintptr
@ -364,7 +367,7 @@ func (vss *IVssBackupComponents) convertToVSSAsync(
}
// IsVolumeSupported calls the equivalent VSS api.
func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, error) {
func (vss *IVssBackupComponents) IsVolumeSupported(providerID *ole.GUID, volumeName string) (bool, error) {
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
if err != nil {
panic(err)
@ -374,7 +377,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL))
id := (*[4]uintptr)(unsafe.Pointer(providerID))
result, _, _ = syscall.Syscall9(vss.getVTable().isVolumeSupported, 7,
uintptr(unsafe.Pointer(vss)), id[0], id[1], id[2], id[3],
@ -382,7 +385,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
0)
} else {
result, _, _ = syscall.Syscall6(vss.getVTable().isVolumeSupported, 4,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(ole.IID_NULL)),
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(providerID)),
uintptr(unsafe.Pointer(volumeNamePointer)), uintptr(unsafe.Pointer(&isSupportedRaw)), 0,
0)
}
@ -408,24 +411,24 @@ func (vss *IVssBackupComponents) StartSnapshotSet() (ole.GUID, error) {
}
// AddToSnapshotSet calls the equivalent VSS api.
func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, idSnapshot *ole.GUID) error {
func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, providerID *ole.GUID, idSnapshot *ole.GUID) error {
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
if err != nil {
panic(err)
}
var result uintptr = 0
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL))
id := (*[4]uintptr)(unsafe.Pointer(providerID))
result, _, _ = syscall.Syscall9(vss.getVTable().addToSnapshotSet, 7,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), id[0], id[1],
id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
id[0], id[1], id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
} else {
result, _, _ = syscall.Syscall6(vss.getVTable().addToSnapshotSet, 4,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
uintptr(unsafe.Pointer(ole.IID_NULL)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
uintptr(unsafe.Pointer(providerID)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
}
return newVssErrorIfResultNotOK("AddToSnapshotSet() failed", HRESULT(result))
@ -478,9 +481,9 @@ func (vss *IVssBackupComponents) DoSnapshotSet() (*IVSSAsync, error) {
// DeleteSnapshots calls the equivalent VSS api.
func (vss *IVssBackupComponents) DeleteSnapshots(snapshotID ole.GUID) (int32, ole.GUID, error) {
var deletedSnapshots int32 = 0
var deletedSnapshots int32
var nondeletedSnapshotID ole.GUID
var result uintptr = 0
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(&snapshotID))
@ -504,7 +507,7 @@ func (vss *IVssBackupComponents) DeleteSnapshots(snapshotID ole.GUID) (int32, ol
// GetSnapshotProperties calls the equivalent VSS api.
func (vss *IVssBackupComponents) GetSnapshotProperties(snapshotID ole.GUID,
properties *VssSnapshotProperties) error {
var result uintptr = 0
var result uintptr
if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(&snapshotID))
@ -527,11 +530,18 @@ func vssFreeSnapshotProperties(properties *VssSnapshotProperties) error {
if err != nil {
return err
}
proc.Call(uintptr(unsafe.Pointer(properties)))
// this function always succeeds and returns no value
_, _, _ = proc.Call(uintptr(unsafe.Pointer(properties)))
return nil
}
func vssFreeProviderProperties(p *VssProviderProperties) {
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerName)))
p.providerName = nil
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerVersion)))
p.providerName = nil
}
// BackupComplete calls the equivalent VSS api.
func (vss *IVssBackupComponents) BackupComplete() (*IVSSAsync, error) {
var oleIUnknown *ole.IUnknown
@ -543,6 +553,7 @@ func (vss *IVssBackupComponents) BackupComplete() (*IVSSAsync, error) {
}
// VssSnapshotProperties defines the properties of a VSS snapshot as part of the VSS api.
// nolint:structcheck
type VssSnapshotProperties struct {
snapshotID ole.GUID
snapshotSetID ole.GUID
@ -559,6 +570,17 @@ type VssSnapshotProperties struct {
status uint
}
// VssProviderProperties defines the properties of a VSS provider as part of the VSS api.
// nolint:structcheck
type VssProviderProperties struct {
providerID ole.GUID
providerName *uint16
providerType uint32
providerVersion *uint16
providerVersionID ole.GUID
classID ole.GUID
}
// GetSnapshotDeviceObject returns root path to access the snapshot files
// and folders.
func (p *VssSnapshotProperties) GetSnapshotDeviceObject() string {
@ -617,8 +639,13 @@ func (vssAsync *IVSSAsync) QueryStatus() (HRESULT, uint32) {
// WaitUntilAsyncFinished waits until either the async call is finished or
// the given timeout is reached.
func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(millis uint32) error {
hresult := vssAsync.Wait(millis)
func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(timeout time.Duration) error {
const maxTimeout = 2147483647 * time.Millisecond
if timeout > maxTimeout {
timeout = maxTimeout
}
hresult := vssAsync.Wait(uint32(timeout.Milliseconds()))
err := newVssErrorIfResultNotOK("Wait() failed", hresult)
if err != nil {
vssAsync.Cancel()
@ -651,6 +678,75 @@ func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(millis uint32) error {
return nil
}
// UIID_IVSS_ADMIN defines the GUID of IVSSAdmin.
var (
UIID_IVSS_ADMIN = ole.NewGUID("{77ED5996-2F63-11d3-8A39-00C04F72D8E3}")
CLSID_VSS_COORDINATOR = ole.NewGUID("{E579AB5F-1CC4-44b4-BED9-DE0991FF0623}")
)
// IVSSAdmin VSS api interface.
type IVSSAdmin struct {
ole.IUnknown
}
// IVSSAdminVTable is the vtable for IVSSAdmin.
// nolint:structcheck
type IVSSAdminVTable struct {
ole.IUnknownVtbl
registerProvider uintptr
unregisterProvider uintptr
queryProviders uintptr
abortAllSnapshotsInProgress uintptr
}
// getVTable returns the vtable for IVSSAdmin.
func (vssAdmin *IVSSAdmin) getVTable() *IVSSAdminVTable {
return (*IVSSAdminVTable)(unsafe.Pointer(vssAdmin.RawVTable))
}
// QueryProviders calls the equivalent VSS api.
func (vssAdmin *IVSSAdmin) QueryProviders() (*IVssEnumObject, error) {
var enum *IVssEnumObject
result, _, _ := syscall.Syscall(vssAdmin.getVTable().queryProviders, 2,
uintptr(unsafe.Pointer(vssAdmin)), uintptr(unsafe.Pointer(&enum)), 0)
return enum, newVssErrorIfResultNotOK("QueryProviders() failed", HRESULT(result))
}
// IVssEnumObject VSS api interface.
type IVssEnumObject struct {
ole.IUnknown
}
// IVssEnumObjectVTable is the vtable for IVssEnumObject.
// nolint:structcheck
type IVssEnumObjectVTable struct {
ole.IUnknownVtbl
next uintptr
skip uintptr
reset uintptr
clone uintptr
}
// getVTable returns the vtable for IVssEnumObject.
func (vssEnum *IVssEnumObject) getVTable() *IVssEnumObjectVTable {
return (*IVssEnumObjectVTable)(unsafe.Pointer(vssEnum.RawVTable))
}
// Next calls the equivalent VSS api.
func (vssEnum *IVssEnumObject) Next(count uint, props unsafe.Pointer) (uint, error) {
var fetched uint32
result, _, _ := syscall.Syscall6(vssEnum.getVTable().next, 4,
uintptr(unsafe.Pointer(vssEnum)), uintptr(count), uintptr(props),
uintptr(unsafe.Pointer(&fetched)), 0, 0)
if result == 1 {
return uint(fetched), nil
}
return uint(fetched), newVssErrorIfResultNotOK("Next() failed", HRESULT(result))
}
// MountPoint wraps all information of a snapshot of a mountpoint on a volume.
type MountPoint struct {
isSnapshotted bool
@ -677,7 +773,7 @@ type VssSnapshot struct {
snapshotProperties VssSnapshotProperties
snapshotDeviceObject string
mountPointInfo map[string]MountPoint
timeoutInMillis uint32
timeout time.Duration
}
// GetSnapshotDeviceObject returns root path to access the snapshot files
@ -694,7 +790,12 @@ func initializeVssCOMInterface() (*ole.IUnknown, error) {
}
// ensure COM is initialized before use
ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
if err = ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
// CoInitializeEx returns 1 if COM is already initialized
if oleErr, ok := err.(*ole.OleError); !ok || oleErr.Code() != 1 {
return nil, err
}
}
var oleIUnknown *ole.IUnknown
result, _, _ := vssInstance.Call(uintptr(unsafe.Pointer(&oleIUnknown)))
@ -727,12 +828,34 @@ func HasSufficientPrivilegesForVSS() error {
return err
}
// GetVolumeNameForVolumeMountPoint clear input parameter
// and calls the equivalent windows api.
func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) {
if mountPoint != "" && mountPoint[len(mountPoint)-1] != filepath.Separator {
mountPoint += string(filepath.Separator)
}
mountPointPointer, err := syscall.UTF16PtrFromString(mountPoint)
if err != nil {
return mountPoint, err
}
// A reasonable size for the buffer to accommodate the largest possible
// volume GUID path is 50 characters.
volumeNameBuffer := make([]uint16, 50)
if err := windows.GetVolumeNameForVolumeMountPoint(
mountPointPointer, &volumeNameBuffer[0], 50); err != nil {
return mountPoint, err
}
return syscall.UTF16ToString(volumeNameBuffer), nil
}
// NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't
// finish within the timeout an error is returned.
func NewVssSnapshot(
volume string, timeoutInSeconds uint, msgError ErrorHandler) (VssSnapshot, error) {
func NewVssSnapshot(provider string,
volume string, timeout time.Duration, filter VolumeFilter, msgError ErrorHandler) (VssSnapshot, error) {
is64Bit, err := isRunningOn64BitWindows()
if err != nil {
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"Failed to detect windows architecture: %s", err.Error()))
@ -744,7 +867,7 @@ func NewVssSnapshot(
runtime.GOARCH))
}
timeoutInMillis := uint32(timeoutInSeconds * 1000)
deadline := time.Now().Add(timeout)
oleIUnknown, err := initializeVssCOMInterface()
if oleIUnknown != nil {
@ -778,6 +901,12 @@ func NewVssSnapshot(
iVssBackupComponents := (*IVssBackupComponents)(unsafe.Pointer(comInterface))
providerID, err := getProviderID(provider)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
if err := iVssBackupComponents.InitializeForBackup(); err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
@ -796,13 +925,13 @@ func NewVssSnapshot(
}
err = callAsyncFunctionAndWait(iVssBackupComponents.GatherWriterMetadata,
"GatherWriterMetadata", timeoutInMillis)
"GatherWriterMetadata", deadline)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
if isSupported, err := iVssBackupComponents.IsVolumeSupported(volume); err != nil {
if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, volume); err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
} else if !isSupported {
@ -817,57 +946,66 @@ func NewVssSnapshot(
return VssSnapshot{}, err
}
if err := iVssBackupComponents.AddToSnapshotSet(volume, &snapshotSetID); err != nil {
if err := iVssBackupComponents.AddToSnapshotSet(volume, providerID, &snapshotSetID); err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
mountPoints, err := enumerateMountedFolders(volume)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"failed to enumerate mount points for volume %s: %s", volume, err))
}
mountPointInfo := make(map[string]MountPoint)
for _, mountPoint := range mountPoints {
// ensure every mountpoint is available even without a valid
// snapshot because we need to consider this when backing up files
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: false}
if isSupported, err := iVssBackupComponents.IsVolumeSupported(mountPoint); err != nil {
continue
} else if !isSupported {
continue
}
var mountPointSnapshotSetID ole.GUID
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, &mountPointSnapshotSetID)
// if filter==nil just don't process mount points for this volume at all
if filter != nil {
mountPoints, err := enumerateMountedFolders(volume)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"failed to enumerate mount points for volume %s: %s", volume, err))
}
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: true,
snapshotSetID: mountPointSnapshotSetID}
for _, mountPoint := range mountPoints {
// ensure every mountpoint is available even without a valid
// snapshot because we need to consider this when backing up files
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: false}
if !filter(mountPoint) {
continue
} else if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, mountPoint); err != nil {
continue
} else if !isSupported {
continue
}
var mountPointSnapshotSetID ole.GUID
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, providerID, &mountPointSnapshotSetID)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
mountPointInfo[mountPoint] = MountPoint{
isSnapshotted: true,
snapshotSetID: mountPointSnapshotSetID,
}
}
}
err = callAsyncFunctionAndWait(iVssBackupComponents.PrepareForBackup, "PrepareForBackup",
timeoutInMillis)
deadline)
if err != nil {
// After calling PrepareForBackup one needs to call AbortBackup() before releasing the VSS
// instance for proper cleanup.
// It is not necessary to call BackupComplete before releasing the VSS instance afterwards.
iVssBackupComponents.AbortBackup()
_ = iVssBackupComponents.AbortBackup()
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
err = callAsyncFunctionAndWait(iVssBackupComponents.DoSnapshotSet, "DoSnapshotSet",
timeoutInMillis)
deadline)
if err != nil {
iVssBackupComponents.AbortBackup()
_ = iVssBackupComponents.AbortBackup()
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
@ -875,13 +1013,12 @@ func NewVssSnapshot(
var snapshotProperties VssSnapshotProperties
err = iVssBackupComponents.GetSnapshotProperties(snapshotSetID, &snapshotProperties)
if err != nil {
iVssBackupComponents.AbortBackup()
_ = iVssBackupComponents.AbortBackup()
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
for mountPoint, info := range mountPointInfo {
if !info.isSnapshotted {
continue
}
@ -900,8 +1037,10 @@ func NewVssSnapshot(
mountPointInfo[mountPoint] = info
}
return VssSnapshot{iVssBackupComponents, snapshotSetID, snapshotProperties,
snapshotProperties.GetSnapshotDeviceObject(), mountPointInfo, timeoutInMillis}, nil
return VssSnapshot{
iVssBackupComponents, snapshotSetID, snapshotProperties,
snapshotProperties.GetSnapshotDeviceObject(), mountPointInfo, time.Until(deadline),
}, nil
}
// Delete deletes the created snapshot.
@ -922,15 +1061,17 @@ func (p *VssSnapshot) Delete() error {
if p.iVssBackupComponents != nil {
defer p.iVssBackupComponents.Release()
deadline := time.Now().Add(p.timeout)
err = callAsyncFunctionAndWait(p.iVssBackupComponents.BackupComplete, "BackupComplete",
p.timeoutInMillis)
deadline)
if err != nil {
return err
}
if _, _, e := p.iVssBackupComponents.DeleteSnapshots(p.snapshotID); e != nil {
err = newVssTextError(fmt.Sprintf("Failed to delete snapshot: %s", e.Error()))
p.iVssBackupComponents.AbortBackup()
_ = p.iVssBackupComponents.AbortBackup()
if err != nil {
return err
}
@ -940,12 +1081,61 @@ func (p *VssSnapshot) Delete() error {
return nil
}
func getProviderID(provider string) (*ole.GUID, error) {
comInterface, err := ole.CreateInstance(CLSID_VSS_COORDINATOR, UIID_IVSS_ADMIN)
if err != nil {
return nil, err
}
defer comInterface.Release()
vssAdmin := (*IVSSAdmin)(unsafe.Pointer(comInterface))
providerLower := strings.ToLower(provider)
switch providerLower {
case "":
return ole.IID_NULL, nil
case "ms":
return ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}"), nil
}
enum, err := vssAdmin.QueryProviders()
if err != nil {
return nil, err
}
defer enum.Release()
id := ole.NewGUID(provider)
var props struct {
objectType uint32
provider VssProviderProperties
}
for {
count, err := enum.Next(1, unsafe.Pointer(&props))
if err != nil {
return nil, err
}
if count < 1 {
return nil, errors.Errorf(`invalid VSS provider "%s"`, provider)
}
name := ole.UTF16PtrToString(props.provider.providerName)
vssFreeProviderProperties(&props.provider)
if id != nil && *id == props.provider.providerID ||
id == nil && providerLower == strings.ToLower(name) {
return &props.provider.providerID, nil
}
}
}
// asyncCallFunc is the callback type for callAsyncFunctionAndWait.
type asyncCallFunc func() (*IVSSAsync, error)
// callAsyncFunctionAndWait calls an async functions and waits for it to either
// finish or timeout.
func callAsyncFunctionAndWait(function asyncCallFunc, name string, timeoutInMillis uint32) error {
func callAsyncFunctionAndWait(function asyncCallFunc, name string, deadline time.Time) error {
iVssAsync, err := function()
if err != nil {
return err
@ -955,7 +1145,12 @@ func callAsyncFunctionAndWait(function asyncCallFunc, name string, timeoutInMill
return newVssTextError(fmt.Sprintf("%s() returned nil", name))
}
err = iVssAsync.WaitUntilAsyncFinished(timeoutInMillis)
timeout := time.Until(deadline)
if timeout <= 0 {
return newVssTextError(fmt.Sprintf("%s() deadline exceeded", name))
}
err = iVssAsync.WaitUntilAsyncFinished(timeout)
iVssAsync.Release()
return err
}
@ -1036,6 +1231,7 @@ func enumerateMountedFolders(volume string) ([]string, error) {
return mountedFolders, nil
}
// nolint:errcheck
defer windows.FindVolumeMountPointClose(handle)
volumeMountPoint := syscall.UTF16ToString(volumeMountPointBuffer)