diff --git a/cmd/restic/cmd_check.go b/cmd/restic/cmd_check.go index b9f3199b2..3c4c9daa9 100644 --- a/cmd/restic/cmd_check.go +++ b/cmd/restic/cmd_check.go @@ -16,6 +16,7 @@ import ( "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/restic" + "github.com/restic/restic/internal/ui" ) var cmdCheck = &cobra.Command{ @@ -97,7 +98,7 @@ func checkFlags(opts CheckOptions) error { } } else { - fileSize, err := parseSizeStr(opts.ReadDataSubset) + fileSize, err := ui.ParseBytes(opts.ReadDataSubset) if err != nil { return argumentError } @@ -363,7 +364,7 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args if repoSize == 0 { return errors.Fatal("Cannot read from a repository having size 0") } - subsetSize, _ := parseSizeStr(opts.ReadDataSubset) + subsetSize, _ := ui.ParseBytes(opts.ReadDataSubset) if subsetSize > repoSize { subsetSize = repoSize } diff --git a/cmd/restic/cmd_prune.go b/cmd/restic/cmd_prune.go index 1889dffd6..e4c2c7b29 100644 --- a/cmd/restic/cmd_prune.go +++ b/cmd/restic/cmd_prune.go @@ -81,7 +81,7 @@ func addPruneOptions(c *cobra.Command) { func verifyPruneOptions(opts *PruneOptions) error { opts.MaxRepackBytes = math.MaxUint64 if len(opts.MaxRepackSize) > 0 { - size, err := parseSizeStr(opts.MaxRepackSize) + size, err := ui.ParseBytes(opts.MaxRepackSize) if err != nil { return err } @@ -124,7 +124,7 @@ func verifyPruneOptions(opts *PruneOptions) error { } default: - size, err := parseSizeStr(maxUnused) + size, err := ui.ParseBytes(maxUnused) if err != nil { return errors.Fatalf("invalid number of bytes %q for --max-unused: %v", opts.MaxUnused, err) } diff --git a/cmd/restic/exclude.go b/cmd/restic/exclude.go index efe6f41e4..095944610 100644 --- a/cmd/restic/exclude.go +++ b/cmd/restic/exclude.go @@ -7,7 +7,6 @@ import ( "io" "os" "path/filepath" - "strconv" "strings" "sync" @@ -17,6 +16,7 @@ import ( "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/textfile" + "github.com/restic/restic/internal/ui" "github.com/spf13/pflag" ) @@ -364,7 +364,7 @@ func rejectResticCache(repo *repository.Repository) (RejectByNameFunc, error) { } func rejectBySize(maxSizeStr string) (RejectFunc, error) { - maxSize, err := parseSizeStr(maxSizeStr) + maxSize, err := ui.ParseBytes(maxSizeStr) if err != nil { return nil, err } @@ -385,35 +385,6 @@ func rejectBySize(maxSizeStr string) (RejectFunc, error) { }, nil } -func parseSizeStr(sizeStr string) (int64, error) { - if sizeStr == "" { - return 0, errors.New("expected size, got empty string") - } - - numStr := sizeStr[:len(sizeStr)-1] - var unit int64 = 1 - - switch sizeStr[len(sizeStr)-1] { - case 'b', 'B': - // use initialized values, do nothing here - case 'k', 'K': - unit = 1024 - case 'm', 'M': - unit = 1024 * 1024 - case 'g', 'G': - unit = 1024 * 1024 * 1024 - case 't', 'T': - unit = 1024 * 1024 * 1024 * 1024 - default: - numStr = sizeStr - } - value, err := strconv.ParseInt(numStr, 10, 64) - if err != nil { - return 0, err - } - return value * unit, nil -} - // readExcludePatternsFromFiles reads all exclude files and returns the list of // exclude patterns. For each line, leading and trailing white space is removed // and comment lines are ignored. For each remaining pattern, environment diff --git a/cmd/restic/exclude_test.go b/cmd/restic/exclude_test.go index 050a083e4..9a24418ae 100644 --- a/cmd/restic/exclude_test.go +++ b/cmd/restic/exclude_test.go @@ -187,54 +187,6 @@ func TestMultipleIsExcludedByFile(t *testing.T) { } } -func TestParseSizeStr(t *testing.T) { - sizeStrTests := []struct { - in string - expected int64 - }{ - {"1024", 1024}, - {"1024b", 1024}, - {"1024B", 1024}, - {"1k", 1024}, - {"100k", 102400}, - {"100K", 102400}, - {"10M", 10485760}, - {"100m", 104857600}, - {"20G", 21474836480}, - {"10g", 10737418240}, - {"2T", 2199023255552}, - {"2t", 2199023255552}, - } - - for _, tt := range sizeStrTests { - actual, err := parseSizeStr(tt.in) - test.OK(t, err) - - if actual != tt.expected { - t.Errorf("parseSizeStr(%s) = %d; expected %d", tt.in, actual, tt.expected) - } - } -} - -func TestParseInvalidSizeStr(t *testing.T) { - invalidSizes := []string{ - "", - " ", - "foobar", - "zzz", - } - - for _, s := range invalidSizes { - v, err := parseSizeStr(s) - if err == nil { - t.Errorf("wanted error for invalid value %q, got nil", s) - } - if v != 0 { - t.Errorf("wanted zero for invalid value %q, got: %v", s, v) - } - } -} - // TestIsExcludedByFileSize is for testing the instance of // --exclude-larger-than parameters func TestIsExcludedByFileSize(t *testing.T) { diff --git a/internal/ui/format.go b/internal/ui/format.go index 34c97703a..d2e0a4d2b 100644 --- a/internal/ui/format.go +++ b/internal/ui/format.go @@ -3,7 +3,10 @@ package ui import ( "bytes" "encoding/json" + "errors" "fmt" + "math/bits" + "strconv" "time" ) @@ -56,6 +59,44 @@ func FormatSeconds(sec uint64) string { return fmt.Sprintf("%d:%02d", min, sec) } +// ParseBytes parses a size in bytes from s. It understands the suffixes +// B, K, M, G and T for powers of 1024. +func ParseBytes(s string) (int64, error) { + if s == "" { + return 0, errors.New("expected size, got empty string") + } + + numStr := s[:len(s)-1] + var unit uint64 = 1 + + switch s[len(s)-1] { + case 'b', 'B': + // use initialized values, do nothing here + case 'k', 'K': + unit = 1024 + case 'm', 'M': + unit = 1024 * 1024 + case 'g', 'G': + unit = 1024 * 1024 * 1024 + case 't', 'T': + unit = 1024 * 1024 * 1024 * 1024 + default: + numStr = s + } + value, err := strconv.ParseInt(numStr, 10, 64) + if err != nil { + return 0, err + } + + hi, lo := bits.Mul64(uint64(value), unit) + value = int64(lo) + if hi != 0 || value < 0 { + return 0, fmt.Errorf("ParseSize: %q: %w", numStr, strconv.ErrRange) + } + + return value, nil +} + func ToJSONString(status interface{}) string { buf := new(bytes.Buffer) err := json.NewEncoder(buf).Encode(status) diff --git a/internal/ui/format_test.go b/internal/ui/format_test.go index b6a1c13d1..4223d4e20 100644 --- a/internal/ui/format_test.go +++ b/internal/ui/format_test.go @@ -1,6 +1,10 @@ package ui -import "testing" +import ( + "testing" + + "github.com/restic/restic/internal/test" +) func TestFormatBytes(t *testing.T) { for _, c := range []struct { @@ -36,3 +40,47 @@ func TestFormatPercent(t *testing.T) { } } } + +func TestParseBytes(t *testing.T) { + for _, tt := range []struct { + in string + expected int64 + }{ + {"1024", 1024}, + {"1024b", 1024}, + {"1024B", 1024}, + {"1k", 1024}, + {"100k", 102400}, + {"100K", 102400}, + {"10M", 10485760}, + {"100m", 104857600}, + {"20G", 21474836480}, + {"10g", 10737418240}, + {"2T", 2199023255552}, + {"2t", 2199023255552}, + {"9223372036854775807", 1<<63 - 1}, + } { + actual, err := ParseBytes(tt.in) + test.OK(t, err) + test.Equals(t, tt.expected, actual) + } +} + +func TestParseBytesInvalid(t *testing.T) { + for _, s := range []string{ + "", + " ", + "foobar", + "zzz", + "18446744073709551615", // 1<<64-1. + "9223372036854775807k", // 1<<63-1 kiB. + "9999999999999M", + "99999999999999999999", + } { + v, err := ParseBytes(s) + if err == nil { + t.Errorf("wanted error for invalid value %q, got nil", s) + } + test.Equals(t, int64(0), v) + } +}