use standalone shutdown hook for readPasswordTerminal

move terminal restoration into readPasswordTerminal
This commit is contained in:
Michael Eischer 2024-03-29 23:52:45 +01:00
parent 86c7909f41
commit eb710a28e8
7 changed files with 58 additions and 65 deletions

View File

@ -53,7 +53,7 @@ func init() {
} }
func runCopy(ctx context.Context, opts CopyOptions, gopts GlobalOptions, args []string) error { func runCopy(ctx context.Context, opts CopyOptions, gopts GlobalOptions, args []string) error {
secondaryGopts, isFromRepo, err := fillSecondaryGlobalOpts(opts.secondaryRepoOptions, gopts, "destination") secondaryGopts, isFromRepo, err := fillSecondaryGlobalOpts(ctx, opts.secondaryRepoOptions, gopts, "destination")
if err != nil { if err != nil {
return err return err
} }

View File

@ -80,7 +80,7 @@ func runInit(ctx context.Context, opts InitOptions, gopts GlobalOptions, args []
return err return err
} }
gopts.password, err = ReadPasswordTwice(gopts, gopts.password, err = ReadPasswordTwice(ctx, gopts,
"enter password for new repository: ", "enter password for new repository: ",
"enter password again: ") "enter password again: ")
if err != nil { if err != nil {
@ -131,7 +131,7 @@ func runInit(ctx context.Context, opts InitOptions, gopts GlobalOptions, args []
func maybeReadChunkerPolynomial(ctx context.Context, opts InitOptions, gopts GlobalOptions) (*chunker.Pol, error) { func maybeReadChunkerPolynomial(ctx context.Context, opts InitOptions, gopts GlobalOptions) (*chunker.Pol, error) {
if opts.CopyChunkerParameters { if opts.CopyChunkerParameters {
otherGopts, _, err := fillSecondaryGlobalOpts(opts.secondaryRepoOptions, gopts, "secondary") otherGopts, _, err := fillSecondaryGlobalOpts(ctx, opts.secondaryRepoOptions, gopts, "secondary")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -60,7 +60,7 @@ func runKeyAdd(ctx context.Context, gopts GlobalOptions, opts KeyAddOptions, arg
} }
func addKey(ctx context.Context, repo *repository.Repository, gopts GlobalOptions, opts KeyAddOptions) error { func addKey(ctx context.Context, repo *repository.Repository, gopts GlobalOptions, opts KeyAddOptions) error {
pw, err := getNewPassword(gopts, opts.NewPasswordFile) pw, err := getNewPassword(ctx, gopts, opts.NewPasswordFile)
if err != nil { if err != nil {
return err return err
} }
@ -83,7 +83,7 @@ func addKey(ctx context.Context, repo *repository.Repository, gopts GlobalOption
// testKeyNewPassword is used to set a new password during integration testing. // testKeyNewPassword is used to set a new password during integration testing.
var testKeyNewPassword string var testKeyNewPassword string
func getNewPassword(gopts GlobalOptions, newPasswordFile string) (string, error) { func getNewPassword(ctx context.Context, gopts GlobalOptions, newPasswordFile string) (string, error) {
if testKeyNewPassword != "" { if testKeyNewPassword != "" {
return testKeyNewPassword, nil return testKeyNewPassword, nil
} }
@ -97,7 +97,7 @@ func getNewPassword(gopts GlobalOptions, newPasswordFile string) (string, error)
newopts := gopts newopts := gopts
newopts.password = "" newopts.password = ""
return ReadPasswordTwice(newopts, return ReadPasswordTwice(ctx, newopts,
"enter new password: ", "enter new password: ",
"enter password again: ") "enter password again: ")
} }

View File

@ -57,7 +57,7 @@ func runKeyPasswd(ctx context.Context, gopts GlobalOptions, opts KeyPasswdOption
} }
func changePassword(ctx context.Context, repo *repository.Repository, gopts GlobalOptions, opts KeyPasswdOptions) error { func changePassword(ctx context.Context, repo *repository.Repository, gopts GlobalOptions, opts KeyPasswdOptions) error {
pw, err := getNewPassword(gopts, opts.NewPasswordFile) pw, err := getNewPassword(ctx, gopts, opts.NewPasswordFile)
if err != nil { if err != nil {
return err return err
} }

View File

@ -96,7 +96,6 @@ var globalOptions = GlobalOptions{
stderr: os.Stderr, stderr: os.Stderr,
} }
var isReadingPassword bool
var internalGlobalCtx context.Context var internalGlobalCtx context.Context
func init() { func init() {
@ -165,8 +164,6 @@ func init() {
// parse target pack size from env, on error the default value will be used // parse target pack size from env, on error the default value will be used
targetPackSize, _ := strconv.ParseUint(os.Getenv("RESTIC_PACK_SIZE"), 10, 32) targetPackSize, _ := strconv.ParseUint(os.Getenv("RESTIC_PACK_SIZE"), 10, 32)
globalOptions.PackSize = uint(targetPackSize) globalOptions.PackSize = uint(targetPackSize)
restoreTerminal()
} }
func stdinIsTerminal() bool { func stdinIsTerminal() bool {
@ -191,40 +188,6 @@ func stdoutTerminalWidth() int {
return w return w
} }
// restoreTerminal installs a cleanup handler that restores the previous
// terminal state on exit. This handler is only intended to restore the
// terminal configuration if restic exits after receiving a signal. A regular
// program execution must revert changes to the terminal configuration itself.
// The terminal configuration is only restored while reading a password.
func restoreTerminal() {
if !term.IsTerminal(int(os.Stdout.Fd())) {
return
}
fd := int(os.Stdout.Fd())
state, err := term.GetState(fd)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to get terminal state: %v\n", err)
return
}
AddCleanupHandler(func(code int) (int, error) {
// Restoring the terminal configuration while restic runs in the
// background, causes restic to get stopped on unix systems with
// a SIGTTOU signal. Thus only restore the terminal settings if
// they might have been modified, which is the case while reading
// a password.
if !isReadingPassword {
return code, nil
}
err := term.Restore(fd, state)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to restore terminal state: %v\n", err)
}
return code, err
})
}
// ClearLine creates a platform dependent string to clear the current // ClearLine creates a platform dependent string to clear the current
// line, so it can be overwritten. // line, so it can be overwritten.
// //
@ -333,24 +296,48 @@ func readPassword(in io.Reader) (password string, err error) {
// readPasswordTerminal reads the password from the given reader which must be a // readPasswordTerminal reads the password from the given reader which must be a
// tty. Prompt is printed on the writer out before attempting to read the // tty. Prompt is printed on the writer out before attempting to read the
// password. // password. If the context is canceled, the function leaks the password reading
func readPasswordTerminal(in *os.File, out io.Writer, prompt string) (password string, err error) { // goroutine.
func readPasswordTerminal(ctx context.Context, in *os.File, out *os.File, prompt string) (password string, err error) {
fd := int(out.Fd())
state, err := term.GetState(fd)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to get terminal state: %v\n", err)
return "", err
}
done := make(chan struct{})
var buf []byte
go func() {
defer close(done)
fmt.Fprint(out, prompt) fmt.Fprint(out, prompt)
isReadingPassword = true buf, err = term.ReadPassword(int(in.Fd()))
buf, err := term.ReadPassword(int(in.Fd()))
isReadingPassword = false
fmt.Fprintln(out) fmt.Fprintln(out)
}()
select {
case <-ctx.Done():
err := term.Restore(fd, state)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to restore terminal state: %v\n", err)
}
return "", ctx.Err()
case <-done:
// clean shutdown, nothing to do
}
if err != nil { if err != nil {
return "", errors.Wrap(err, "ReadPassword") return "", errors.Wrap(err, "ReadPassword")
} }
password = string(buf) return string(buf), nil
return password, nil
} }
// ReadPassword reads the password from a password file, the environment // ReadPassword reads the password from a password file, the environment
// variable RESTIC_PASSWORD or prompts the user. // variable RESTIC_PASSWORD or prompts the user. If the context is canceled,
func ReadPassword(opts GlobalOptions, prompt string) (string, error) { // the function leaks the password reading goroutine.
func ReadPassword(ctx context.Context, opts GlobalOptions, prompt string) (string, error) {
if opts.password != "" { if opts.password != "" {
return opts.password, nil return opts.password, nil
} }
@ -361,7 +348,7 @@ func ReadPassword(opts GlobalOptions, prompt string) (string, error) {
) )
if stdinIsTerminal() { if stdinIsTerminal() {
password, err = readPasswordTerminal(os.Stdin, os.Stderr, prompt) password, err = readPasswordTerminal(ctx, os.Stdin, os.Stderr, prompt)
} else { } else {
password, err = readPassword(os.Stdin) password, err = readPassword(os.Stdin)
Verbosef("reading repository password from stdin\n") Verbosef("reading repository password from stdin\n")
@ -379,14 +366,15 @@ func ReadPassword(opts GlobalOptions, prompt string) (string, error) {
} }
// ReadPasswordTwice calls ReadPassword two times and returns an error when the // ReadPasswordTwice calls ReadPassword two times and returns an error when the
// passwords don't match. // passwords don't match. If the context is canceled, the function leaks the
func ReadPasswordTwice(gopts GlobalOptions, prompt1, prompt2 string) (string, error) { // password reading goroutine.
pw1, err := ReadPassword(gopts, prompt1) func ReadPasswordTwice(ctx context.Context, gopts GlobalOptions, prompt1, prompt2 string) (string, error) {
pw1, err := ReadPassword(ctx, gopts, prompt1)
if err != nil { if err != nil {
return "", err return "", err
} }
if stdinIsTerminal() { if stdinIsTerminal() {
pw2, err := ReadPassword(gopts, prompt2) pw2, err := ReadPassword(ctx, gopts, prompt2)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -469,7 +457,10 @@ func OpenRepository(ctx context.Context, opts GlobalOptions) (*repository.Reposi
} }
for ; passwordTriesLeft > 0; passwordTriesLeft-- { for ; passwordTriesLeft > 0; passwordTriesLeft-- {
opts.password, err = ReadPassword(opts, "enter password for repository: ") opts.password, err = ReadPassword(ctx, opts, "enter password for repository: ")
if ctx.Err() != nil {
return nil, ctx.Err()
}
if err != nil && passwordTriesLeft > 1 { if err != nil && passwordTriesLeft > 1 {
opts.password = "" opts.password = ""
fmt.Printf("%s. Try again\n", err) fmt.Printf("%s. Try again\n", err)

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"os" "os"
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
@ -56,7 +57,7 @@ func initSecondaryRepoOptions(f *pflag.FlagSet, opts *secondaryRepoOptions, repo
opts.PasswordCommand = os.Getenv("RESTIC_FROM_PASSWORD_COMMAND") opts.PasswordCommand = os.Getenv("RESTIC_FROM_PASSWORD_COMMAND")
} }
func fillSecondaryGlobalOpts(opts secondaryRepoOptions, gopts GlobalOptions, repoPrefix string) (GlobalOptions, bool, error) { func fillSecondaryGlobalOpts(ctx context.Context, opts secondaryRepoOptions, gopts GlobalOptions, repoPrefix string) (GlobalOptions, bool, error) {
if opts.Repo == "" && opts.RepositoryFile == "" && opts.LegacyRepo == "" && opts.LegacyRepositoryFile == "" { if opts.Repo == "" && opts.RepositoryFile == "" && opts.LegacyRepo == "" && opts.LegacyRepositoryFile == "" {
return GlobalOptions{}, false, errors.Fatal("Please specify a source repository location (--from-repo or --from-repository-file)") return GlobalOptions{}, false, errors.Fatal("Please specify a source repository location (--from-repo or --from-repository-file)")
} }
@ -109,7 +110,7 @@ func fillSecondaryGlobalOpts(opts secondaryRepoOptions, gopts GlobalOptions, rep
return GlobalOptions{}, false, err return GlobalOptions{}, false, err
} }
} }
dstGopts.password, err = ReadPassword(dstGopts, "enter password for "+repoPrefix+" repository: ") dstGopts.password, err = ReadPassword(ctx, dstGopts, "enter password for "+repoPrefix+" repository: ")
if err != nil { if err != nil {
return GlobalOptions{}, false, err return GlobalOptions{}, false, err
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -170,7 +171,7 @@ func TestFillSecondaryGlobalOpts(t *testing.T) {
// Test all valid cases // Test all valid cases
for _, testCase := range validSecondaryRepoTestCases { for _, testCase := range validSecondaryRepoTestCases {
DstGOpts, isFromRepo, err := fillSecondaryGlobalOpts(testCase.Opts, gOpts, "destination") DstGOpts, isFromRepo, err := fillSecondaryGlobalOpts(context.TODO(), testCase.Opts, gOpts, "destination")
rtest.OK(t, err) rtest.OK(t, err)
rtest.Equals(t, DstGOpts, testCase.DstGOpts) rtest.Equals(t, DstGOpts, testCase.DstGOpts)
rtest.Equals(t, isFromRepo, testCase.FromRepo) rtest.Equals(t, isFromRepo, testCase.FromRepo)
@ -178,7 +179,7 @@ func TestFillSecondaryGlobalOpts(t *testing.T) {
// Test all invalid cases // Test all invalid cases
for _, testCase := range invalidSecondaryRepoTestCases { for _, testCase := range invalidSecondaryRepoTestCases {
_, _, err := fillSecondaryGlobalOpts(testCase.Opts, gOpts, "destination") _, _, err := fillSecondaryGlobalOpts(context.TODO(), testCase.Opts, gOpts, "destination")
rtest.Assert(t, err != nil, "Expected error, but function did not return an error") rtest.Assert(t, err != nil, "Expected error, but function did not return an error")
} }
} }