diff --git a/Makefile b/Makefile index f563cee1..4c771dcc 100644 --- a/Makefile +++ b/Makefile @@ -178,9 +178,9 @@ build-windows: prepare_build ## Build the binary for Windows build-all: build-mac build-linux build-pi build-windows ## Build the binary for all platforms -test: prepare_test ## Run unit tests +test: $(GOBIN)/gotestsum prepare_test ## Run unit tests @echo "[*] $@" - $(GOTEST) $(TESTS) + $(GOBIN)/gotestsum $(TESTS) test-ci: $(GOBIN)/gotestsum prepare_test ## Run unit tests with coverage (for CI) @echo "[*] $@" diff --git a/codecov.yml b/codecov.yml index 98e325fa..95e26d52 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,6 +4,7 @@ ignore: - priority/check/ - util/test_executable/ - term/remote.go + - term/test/ - deprecation.go - logger.go - systemd.go diff --git a/commands.go b/commands.go index 9b101381..d426a05b 100644 --- a/commands.go +++ b/commands.go @@ -22,7 +22,7 @@ import ( "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/platform" "github.com/creativeprojects/resticprofile/remote" - "github.com/creativeprojects/resticprofile/term" + "github.com/creativeprojects/resticprofile/util/ansi" "github.com/creativeprojects/resticprofile/win" "github.com/distatus/battery" ) @@ -190,11 +190,11 @@ func getOwnCommands() []ownCommand { } } -func panicCommand(_ io.Writer, _ commandContext) error { +func panicCommand(_ commandContext) error { panic("you asked for it") } -func completeCommand(output io.Writer, ctx commandContext) error { +func completeCommand(ctx commandContext) error { args := ctx.request.arguments requester := "unknown" requesterVersion := 0 @@ -226,13 +226,13 @@ func completeCommand(output io.Writer, ctx commandContext) error { completions := NewCompleter(ctx.ownCommands.All(), DefaultFlagsLoader, includeDescription).Complete(args) if len(completions) > 0 { for _, completion := range completions { - fmt.Fprintln(output, completion) + ctx.terminal.Println(completion) } } return nil } -func showProfileOrGroup(output io.Writer, ctx commandContext) error { +func showProfileOrGroup(ctx commandContext) error { c := ctx.config flags := ctx.flags @@ -270,21 +270,21 @@ func showProfileOrGroup(output io.Writer, ctx commandContext) error { } // Show global - err = config.ShowStruct(output, global, constants.SectionConfigurationGlobal) + err = config.ShowStruct(ctx.terminal, global, constants.SectionConfigurationGlobal) if err != nil { clog.Errorf("cannot show global section: %s", err.Error()) } - _, _ = fmt.Fprintln(output) + _, _ = ctx.terminal.Println() // Show profile or group - err = config.ShowStruct(output, profileOrGroup, profileOrGroup.Kind()+" "+flags.name) + err = config.ShowStruct(ctx.terminal, profileOrGroup, profileOrGroup.Kind()+" "+flags.name) if err != nil { clog.Errorf("cannot show profile or group '%s': %s", flags.name, err.Error()) } - _, _ = fmt.Fprintln(output) + _, _ = ctx.terminal.Println() // Show schedules - showSchedules(output, slices.Collect(maps.Values(profileOrGroup.Schedules()))) + showSchedules(ctx.terminal, slices.Collect(maps.Values(profileOrGroup.Schedules()))) if profile, ok := profileOrGroup.(*config.Profile); ok { // Show deprecation notice @@ -309,7 +309,7 @@ func showSchedules(output io.Writer, schedules []*config.Schedule) { } // randomKey simply display a base64'd random key to the console -func randomKey(output io.Writer, ctx commandContext) error { +func randomKey(ctx commandContext) error { var err error flags := ctx.flags size := uint64(1024) @@ -329,19 +329,20 @@ func randomKey(output io.Writer, ctx commandContext) error { if err != nil { return err } - encoder := base64.NewEncoder(base64.StdEncoding, output) + encoder := base64.NewEncoder(base64.StdEncoding, ctx.terminal) _, err = encoder.Write(buffer) encoder.Close() - fmt.Fprintln(output, "") + ctx.terminal.Println() return err } -func testElevationCommand(_ io.Writer, ctx commandContext) error { +func testElevationCommand(ctx commandContext) error { if ctx.flags.isChild { client := remote.NewClient(ctx.flags.parentPort) - term.Print("first line", "\n") - term.Println("second", "one") - term.Printf("value = %d\n", 11) + ctx.terminal.Print("simple line", "\n") + ctx.terminal.Printf("value = %d\n", 11) + ctx.terminal.Println(ansi.Bold("in bold")) + clog.Info("log content") err := client.Done() if err != nil { return err @@ -395,7 +396,7 @@ func elevated() error { return nil } -func batteryCommand(stdout io.Writer, _ commandContext) error { +func batteryCommand(ctx commandContext) error { all, err := batt.Batteries() if err != nil { if errors.Is(err, battery.ErrFatal{}) { @@ -407,13 +408,13 @@ func batteryCommand(stdout io.Writer, _ commandContext) error { clog.Info("no battery detected") return nil } - fmt.Fprintln(stdout, "") - w := tabwriter.NewWriter(stdout, 2, 2, 2, ' ', tabwriter.AlignRight) + ctx.terminal.Println() + w := tabwriter.NewWriter(ctx.terminal, 2, 2, 2, ' ', tabwriter.AlignRight) fmt.Fprintf(w, "Battery\tStatus\tCurrent capacity\tFull capacity\tDesign Capacity\tCharge Rate\tVoltage\t\n") for index, batt := range all { fmt.Fprintf(w, "#%d\t%s\t%.2f mWh\t%.2f mWh\t%.2f mWh\t%.2f mWh\t%.2f V\t\n", index, batt.State, batt.Current, batt.Full, batt.Design, batt.ChargeRate, batt.Voltage) } w.Flush() - fmt.Fprintln(stdout, "") + ctx.terminal.Println() return nil } diff --git a/commands_display.go b/commands_display.go index 5737cdf7..9b05262c 100644 --- a/commands_display.go +++ b/commands_display.go @@ -17,30 +17,18 @@ import ( "github.com/creativeprojects/resticprofile/shell" "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util" + "github.com/creativeprojects/resticprofile/util/ansi" "github.com/creativeprojects/resticprofile/util/collect" - "github.com/fatih/color" - colorable "github.com/mattn/go-colorable" ) -var ( - ansiBold = color.New(color.Bold).SprintFunc() - ansiCyan = color.New(color.FgCyan).SprintFunc() - ansiYellow = color.New(color.FgYellow).SprintFunc() -) - -func displayWriter(output io.Writer, flags commandLineFlags) (out func(args ...any) io.Writer, closer func()) { - if term.GetOutput() == output { - output = term.GetColorableOutput() - - if width, _ := term.OsStdoutTerminalSize(); width > 10 { - output = newLineLengthWriter(output, width) +func displayWriter(terminal *term.Terminal) (out func(args ...any) io.Writer, closer func()) { + output := terminal.Stdout() + if terminal.StdoutIsTerminal() { + if width, _ := term.Size(); width > 10 { + output = ansi.NewLineLengthWriter(terminal, width) } } - if flags.noAnsi { - output = colorable.NewNonColorable(output) - } - w := tabwriter.NewWriter(output, 0, 0, 2, ' ', 0) out = func(args ...any) io.Writer { @@ -70,12 +58,12 @@ func getCommonUsageHelpLine(commandName string, withProfile bool) string { } return fmt.Sprintf( "%s [resticprofile flags] %s%s", - ansiBold("resticprofile"), profile, ansiBold(commandName), + ansi.Bold("resticprofile"), profile, ansi.Bold(commandName), ) } -func displayOwnCommands(output io.Writer, ctx commandContext) { - out, closer := displayWriter(output, ctx.flags) +func displayOwnCommands(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() for _, command := range ctx.ownCommands.commands { @@ -87,8 +75,8 @@ func displayOwnCommands(output io.Writer, ctx commandContext) { } } -func displayOwnCommandHelp(output io.Writer, commandName string, ctx commandContext) { - out, closer := displayWriter(output, ctx.flags) +func displayOwnCommandHelp(ctx commandContext, commandName string) { + out, closer := displayWriter(ctx.terminal) defer closer() var command *ownCommand @@ -131,8 +119,8 @@ func displayOwnCommandHelp(output io.Writer, commandName string, ctx commandCont } } -func displayCommonUsageHelp(output io.Writer, ctx commandContext) { - out, closer := displayWriter(output, ctx.flags) +func displayCommonUsageHelp(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() out("resticprofile is a configuration profiles manager for backup profiles and ") @@ -142,26 +130,26 @@ func displayCommonUsageHelp(output io.Writer, ctx commandContext) { out("\t%s [restic flags]\n", getCommonUsageHelpLine("restic-command", true)) out("\t%s [command specific flags]\n", getCommonUsageHelpLine("resticprofile-command", true)) out("\n") - out(ansiBold("resticprofile flags:\n")) + out(ansi.Bold("resticprofile flags:\n")) out(ctx.flags.usagesHelp) out("\n\n") - out(ansiBold("resticprofile own commands:\n")) - displayOwnCommands(out(), ctx) + out(ansi.Bold("resticprofile own commands:\n")) + displayOwnCommands(ctx) out("\n") out("%s at %s\n", - ansiBold("Documentation available"), - ansiBold(ansiCyan("https://creativeprojects.github.io/resticprofile/"))) + ansi.Bold("Documentation available"), + ansi.Bold(ansi.Cyan("https://creativeprojects.github.io/resticprofile/"))) out("\n") } -func displayResticHelp(output io.Writer, configuration *config.Config, flags commandLineFlags, command string) { - out, closer := displayWriter(output, flags) +func displayResticHelp(ctx commandContext, command string) { + out, closer := displayWriter(ctx.terminal) defer closer() resticBinary := "" - if configuration != nil { - if section, err := configuration.GetGlobalSection(); err == nil { + if ctx.config != nil { + if section, err := ctx.config.GetGlobalSection(); err == nil { resticBinary = section.ResticBinary } } @@ -187,11 +175,11 @@ func displayResticHelp(output io.Writer, configuration *config.Config, flags com out("restic binary not found: %s\n", err.Error()) } - if configuration != nil { - out("\nFlags applied by resticprofile (configuration \"%s\"):\n\n", ansiBold(configuration.GetConfigFile())) + if ctx.config != nil { + out("\nFlags applied by resticprofile (configuration \"%s\"):\n\n", ansi.Bold(ctx.config.GetConfigFile())) - if profileNames := configuration.GetProfileNames(); len(profileNames) > 0 { - profiles := configuration.GetProfiles() + if profileNames := ctx.config.GetProfileNames(); len(profileNames) > 0 { + profiles := ctx.config.GetProfiles() sort.Strings(profileNames) unescaper := strings.NewReplacer( `\\`, `^^`, @@ -202,14 +190,14 @@ func displayResticHelp(output io.Writer, configuration *config.Config, flags com ) for _, name := range profileNames { - out("\tprofile \"%s\":", ansiBold(name)) + out("\tprofile \"%s\":", ansi.Bold(name)) profile := profiles[name] cmdFlags := config.GetNonConfidentialArgs(profile, profile.GetCommandFlags(command)) for _, flag := range cmdFlags.GetAll() { if strings.HasPrefix(flag, "-") { out("\n\t\t") } - out("%s\t", ansiCyan(unescaper.Replace(flag))) + out("%s\t", ansi.Cyan(unescaper.Replace(flag))) } out("\n") } @@ -219,12 +207,9 @@ func displayResticHelp(output io.Writer, configuration *config.Config, flags com } } -func displayHelpCommand(output io.Writer, ctx commandContext) error { +func displayHelpCommand(ctx commandContext) error { flags := ctx.flags - out, closer := displayWriter(output, ctx.flags) - defer closer() - if flags.log == "" { clog.GetDefaultLogger().SetHandler(clog.NewDiscardHandler()) // disable log output } @@ -238,23 +223,23 @@ func displayHelpCommand(output io.Writer, ctx commandContext) error { } if helpForCommand == nil { - displayCommonUsageHelp(out("\n"), ctx) + displayCommonUsageHelp(ctx) } else if ctx.ownCommands.Exists(*helpForCommand, true) || ctx.ownCommands.Exists(*helpForCommand, false) { - displayOwnCommandHelp(out("\n"), *helpForCommand, ctx) + displayOwnCommandHelp(ctx, *helpForCommand) } else { - displayResticHelp(out(), ctx.config, flags, *helpForCommand) + displayResticHelp(ctx, *helpForCommand) } return nil } -func displayVersion(output io.Writer, ctx commandContext) error { - out, closer := displayWriter(output, ctx.flags) +func displayVersion(ctx commandContext) error { + out, closer := displayWriter(ctx.terminal) defer closer() - out("resticprofile version %s commit %s\n", ansiBold(version), ansiYellow(commit)) + out("resticprofile version %s commit %s\n", ansi.Bold(version), ansi.Yellow(commit)) // allow for the general verbose flag, or specified after the command arguments := ctx.request.arguments @@ -293,53 +278,53 @@ func displayVersion(output io.Writer, ctx commandContext) error { out("\n") out("\t%s:\n", "go modules") for _, dep := range bi.Deps { - out("\t\t%s\t%s\n", ansiCyan(dep.Path), dep.Version) + out("\t\t%s\t%s\n", ansi.Cyan(dep.Path), dep.Version) } out("\n") } return nil } -func displayProfilesCommand(output io.Writer, ctx commandContext) error { - displayProfiles(output, ctx.config, ctx.flags) - displayGroups(output, ctx.config, ctx.flags) +func displayProfilesCommand(ctx commandContext) error { + displayProfiles(ctx) + displayGroups(ctx) return nil } -func displayProfiles(output io.Writer, configuration *config.Config, flags commandLineFlags) { - out, closer := displayWriter(output, flags) +func displayProfiles(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() - profiles := configuration.GetProfiles() + profiles := ctx.config.GetProfiles() keys := sortedProfileKeys(profiles) if len(profiles) == 0 { out("\nThere's no available profile in the configuration\n") } else { - out("\n%s (name, sections, description):\n", ansiBold("Profiles available")) + out("\n%s (name, sections, description):\n", ansi.Bold("Profiles available")) for _, name := range keys { sections := profiles[name].DefinedCommands() sort.Strings(sections) if len(sections) == 0 { out("\t%s:\t(n/a)\t%s\n", name, profiles[name].Description) } else { - out("\t%s:\t(%s)\t%s\n", name, ansiCyan(strings.Join(sections, ", ")), profiles[name].Description) + out("\t%s:\t(%s)\t%s\n", name, ansi.Cyan(strings.Join(sections, ", ")), profiles[name].Description) } } } out("\n") } -func displayGroups(output io.Writer, configuration *config.Config, flags commandLineFlags) { - out, closer := displayWriter(output, flags) +func displayGroups(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() - groups := configuration.GetProfileGroups() + groups := ctx.config.GetProfileGroups() if len(groups) == 0 { return } - out("%s (name, profiles, description):\n", ansiBold("Groups available")) + out("%s (name, profiles, description):\n", ansi.Bold("Groups available")) for name, groupList := range groups { - out("\t%s:\t[%s]\t%s\n", name, ansiCyan(strings.Join(groupList.Profiles, ", ")), groupList.Description) + out("\t%s:\t[%s]\t%s\n", name, ansi.Cyan(strings.Join(groupList.Profiles, ", ")), groupList.Description) } out("\n") } @@ -352,96 +337,3 @@ func sortedProfileKeys(data map[string]*config.Profile) []string { sort.Strings(keys) return keys } - -// lineLengthWriter limits the max line length, adding line breaks ('\n') as needed. -// the writer detects the right most column (consecutive whitespace) and aligns content if possible. -type lineLengthWriter struct { - writer io.Writer - tokens []byte - maxLineLength, lastWhite, breakLength, lineLength int - ansiLength, lastWhiteAnsiLength int -} - -func newLineLengthWriter(writer io.Writer, maxLineLength int) *lineLengthWriter { - return &lineLengthWriter{writer: writer, maxLineLength: maxLineLength} -} - -func (l *lineLengthWriter) Write(p []byte) (n int, err error) { - var written int - inAnsi := false - offset := l.lineLength - lineLength := func() int { return l.lineLength - l.ansiLength } - - if len(l.tokens) == 0 { - l.tokens = []byte{' ', '\n'} - } - - for i := 0; i < len(p); i++ { - l.lineLength++ - ws := p[i] == l.tokens[0] // ' ' - br := p[i] == l.tokens[1] // '\n' - - // don't count ansi control sequences - if inAnsi = inAnsi || p[i] == 0x1b; inAnsi { - inAnsi = p[i] != 'm' - l.ansiLength++ - continue - } - - if !br && lineLength() > l.maxLineLength && l.lastWhite-offset > 0 { - lastWhiteIndex := l.lastWhite - offset - 1 - remainder := i - lastWhiteIndex - - if written, err = l.writer.Write(p[:lastWhiteIndex]); err == nil { - p = p[lastWhiteIndex+1:] - i = remainder - 1 - n += written + 1 - - _, _ = l.writer.Write(l.tokens[1:]) // write break (instead of WS at lastWhiteIndex) - for range l.breakLength { - _, _ = l.writer.Write(l.tokens[0:1]) // fill spaces for alignment - } - - l.lineLength = l.breakLength + remainder - l.lastWhite = l.breakLength - offset = l.breakLength - - l.ansiLength -= l.lastWhiteAnsiLength - l.lastWhiteAnsiLength = 0 - } else { - return - } - } - - if ws { - if l.lastWhite == l.lineLength-1 && lineLength() < l.maxLineLength*2/3 { - l.breakLength = lineLength() - } - l.lastWhite = l.lineLength - l.lastWhiteAnsiLength = l.ansiLength - - } else if br { - if written, err = l.writer.Write(p[:i+1]); err == nil { - p = p[i+1:] - i = -1 - n += written - - l.lineLength = 0 - l.lastWhite = 0 - l.breakLength = 0 - offset = 0 - - l.ansiLength = 0 - l.lastWhiteAnsiLength = 0 - } else { - return - } - } - } - - // write remainder - if written, err = l.writer.Write(p); err == nil { - n += written - } - return -} diff --git a/commands_display_test.go b/commands_display_test.go index ef27e1cf..1c481048 100644 --- a/commands_display_test.go +++ b/commands_display_test.go @@ -2,11 +2,11 @@ package main import ( "bytes" - "fmt" "runtime" "strings" "testing" + "github.com/creativeprojects/resticprofile/term" "github.com/fatih/color" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -21,290 +21,69 @@ var ansiColor = func() (c *color.Color) { var colored = ansiColor.SprintFunc() func TestDisplayWriter(t *testing.T) { - flags, noAnsiFlags := commandLineFlags{}, commandLineFlags{noAnsi: true} + write := func(v ...any) string { + buffer := new(bytes.Buffer) + recorder, err := term.NewRecorder(buffer) + require.NoError(t, err) - buffer := bytes.Buffer{} - write := func(clf commandLineFlags, v ...any) string { - buffer.Reset() - out, closer := displayWriter(&buffer, clf) + terminal := term.NewTerminal(term.WithStdoutRecorder(recorder), term.WithColors(true)) + out, closer := displayWriter(terminal) out(v...) closer() + recorder.Close() return buffer.String() } t.Run("write-plain", func(t *testing.T) { - actual := write(flags, "hello %s %d") + actual := write("hello %s %d") assert.Equal(t, "hello %s %d", actual) }) t.Run("write-with-format", func(t *testing.T) { - actual := write(flags, "hello %s %02d", "world", 5) + actual := write("hello %s %02d", "world", 5) assert.Equal(t, "hello world 05", actual) }) t.Run("write-tabs", func(t *testing.T) { - actual := write(flags, "col1\tcol2\tcol3\nvalue1\tvalue2\tvalue3") + actual := write("col1\tcol2\tcol3\nvalue1\tvalue2\tvalue3") assert.Equal(t, "col1 col2 col3\nvalue1 value2 value3", actual) }) - t.Run("no-ansi", func(t *testing.T) { - actual := write(flags, colored("test")) + t.Run("ansi", func(t *testing.T) { + actual := write(colored("test")) assert.Equal(t, colored("test"), actual) - - actual = write(noAnsiFlags, colored("test")) - assert.Equal(t, "test", actual) + assert.Contains(t, colored("test"), "\x1b[") }) } -func TestLineLengthWriter(t *testing.T) { - tests := []struct { - input, expected string - chunks, scale int - }{ - // test non-breakable - {input: strings.Repeat("-", 50), expected: strings.Repeat("-", 50), chunks: 15}, - - // test breakable without columns - { - input: strings.Repeat("word ", 20), - expected: "" + - strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + - strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + - strings.Repeat("word ", 4), - chunks: 5, scale: 6, - }, - - // test breakable with ANSI color - { - input: strings.Repeat(colored("word "), 20), - expected: "" + - strings.Repeat(colored("word "), 7) + colored("word\n") + - strings.Repeat(colored("word "), 7) + colored("word\n") + - strings.Repeat(colored("word "), 4), - }, - - // test breakable with 2 columns - { - input: "word word word word " + - strings.Repeat("word ", 20), - expected: "" + - "word word word word " + - "word word word\n" + - strings.Repeat(" word word word\n", 5) + - " word word ", - chunks: 3, scale: 15, - }, - - // test breakable with 2 columns and ANSI color - { - input: colored("word word word word ") + - strings.Repeat(colored("word "), 20), - expected: "" + - colored("word word word word ") + - colored("word ") + colored("word ") + colored("word\n ") + - strings.Repeat(colored("word ")+colored("word ")+colored("word\n "), 5) + - colored("word ") + colored("word "), - }, - - // test real-world content - { - input: ` -Usage of resticprofile: - resticprofile [resticprofile flags] [profile name.][restic command] [restic flags] - resticprofile [resticprofile flags] [profile name.][resticprofile command] [command specific flags] - -resticprofile flags: - -c, --config string configuration file (default "profiles") - --dry-run display the restic commands instead of running them - -f, --format string file format of the configuration (default is to use the file extension) - -h, --help display this help - --lock-wait duration wait up to duration to acquire a lock (syntax "1h5m30s") - -l, --log string logs to a target instead of the console - -n, --name string profile name (default "default") - --no-ansi disable ansi control characters (disable console colouring) - --no-lock skip profile lock file - --no-prio don't set any priority on load: used when started from a service that has already set the priority - -q, --quiet display only warnings and errors - --theme string console colouring theme (dark, light, none) (default "light") - --trace display even more debugging information - -v, --verbose display some debugging information - -w, --wait wait at the end until the user presses the enter key - - -resticprofile own commands: - help display help (run in verbose mode for detailed information) - version display version (run in verbose mode for detailed information) - self-update update to latest resticprofile (use -q/--quiet flag to update without confirmation) - profiles display profile names from the configuration file - show show all the details of the current profile - schedule schedule jobs from a profile (use --all flag to schedule all jobs of all profiles) - unschedule remove scheduled jobs of a profile (use --all flag to unschedule all profiles) - status display the status of scheduled jobs (use --all flag for all profiles) - generate generate resources (--random-key [size], --bash-completion & --zsh-completion) - -Documentation available at https://creativeprojects.github.io/resticprofile/ -`, - expected: ` -Usage of resticprofile: - resticprofile [resticprofile flags] - [profile name.][restic command] - [restic flags] - resticprofile [resticprofile flags] - [profile name.][resticprofile - command] [command specific flags] +func TestDisplayWriterNoColors(t *testing.T) { + write := func(v ...any) string { + buffer := new(bytes.Buffer) + recorder, err := term.NewRecorder(buffer) + require.NoError(t, err) -resticprofile flags: - -c, --config string - configuration - file (default - "profiles") - --dry-run display - the restic - commands - instead of - running them - -f, --format string file - format of the - configuration - (default is to - use the file - extension) - -h, --help display - this help - --lock-wait duration wait up to - duration to acquire a lock - (syntax "1h5m30s") - -l, --log string logs to a - target instead - of the console - -n, --name string profile - name (default - "default") - --no-ansi disable - ansi control - characters - (disable - console - colouring) - --no-lock skip - profile lock - file - --no-prio don't set - any priority - on load: used - when started - from a service - that has - already set - the priority - -q, --quiet display - only warnings - and errors - --theme string console - colouring - theme (dark, - light, none) - (default - "light") - --trace display - even more - debugging - information - -v, --verbose display - some debugging - information - -w, --wait wait at - the end until - the user - presses the - enter key - - -resticprofile own commands: - help display help (run in - verbose mode for - detailed information) - version display version (run - in verbose mode for - detailed information) - self-update update to latest - resticprofile (use - -q/--quiet flag to - update without - confirmation) - profiles display profile names - from the configuration - file - show show all the details - of the current profile - schedule schedule jobs from a - profile (use --all - flag to schedule all - jobs of all profiles) - unschedule remove scheduled jobs - of a profile (use - --all flag to - unschedule all - profiles) - status display the status of - scheduled jobs (use - --all flag for all - profiles) - generate generate resources - (--random-key [size], - --bash-completion & - --zsh-completion) - -Documentation available at -https://creativeprojects.github.io/resticprofile/ -`, - }, - } - - for i, test := range tests { - if test.scale == 0 { - test.scale = 1 - } - for chunkSize := 0; chunkSize <= test.chunks; chunkSize++ { - t.Run(fmt.Sprintf("%d-%d", i, chunkSize), func(t *testing.T) { - buffer := bytes.Buffer{} - writer := newLineLengthWriter(&buffer, 40) - input := []byte(test.input) - - var ( - n int - err error - ) - switch chunkSize { - case 0: - n, err = writer.Write(input) - assert.Equal(t, len(input), n) - default: - for len(input) > 0 && err == nil { - length := min(test.scale*chunkSize, len(input)) - n, err = writer.Write(input[:length]) - assert.Equal(t, length, n) - input = input[length:] - } - } - - assert.Nil(t, err) - assert.Equal(t, test.expected, buffer.String()) - }) - } + terminal := term.NewTerminal(term.WithStdoutRecorder(recorder), term.WithColors(false)) + out, closer := displayWriter(terminal) + out(v...) + closer() + recorder.Close() + return buffer.String() } + actual := write(colored("test")) + assert.Equal(t, "test", actual) + assert.Contains(t, colored("test"), "\x1b[") } func TestDisplayVersionVerbose1(t *testing.T) { buffer := &bytes.Buffer{} - err := displayVersion(buffer, commandContext{Context: Context{flags: commandLineFlags{verbose: true}}}) + err := displayVersion(commandContext{Context: Context{terminal: term.NewTerminal(term.WithStdout(buffer)), flags: commandLineFlags{verbose: true}}}) require.NoError(t, err) assert.True(t, strings.Contains(buffer.String(), runtime.GOOS)) } func TestDisplayVersionVerbose2(t *testing.T) { buffer := &bytes.Buffer{} - err := displayVersion(buffer, commandContext{Context: Context{request: Request{arguments: []string{"-v"}}}}) + err := displayVersion(commandContext{Context: Context{terminal: term.NewTerminal(term.WithStdout(buffer)), request: Request{arguments: []string{"-v"}}}}) require.NoError(t, err) assert.True(t, strings.Contains(buffer.String(), runtime.GOOS)) } diff --git a/commands_generate.go b/commands_generate.go index 00b47907..9b8ed085 100644 --- a/commands_generate.go +++ b/commands_generate.go @@ -29,7 +29,7 @@ var zshCompletionScript string //go:embed contrib/completion/fish-completion.fish var fishCompletionScript string -func generateCommand(output io.Writer, ctx commandContext) (err error) { +func generateCommand(ctx commandContext) (err error) { args := ctx.request.arguments // enforce no-log logger := clog.GetDefaultLogger() @@ -37,18 +37,18 @@ func generateCommand(output io.Writer, ctx commandContext) (err error) { logger.SetHandler(clog.NewDiscardHandler()) if slices.Contains(args, "--bash-completion") { - _, err = fmt.Fprintln(output, bashCompletionScript) + _, err = ctx.terminal.Println(bashCompletionScript) } else if slices.Contains(args, "--config-reference") { - err = generateConfigReference(output, args[slices.Index(args, "--config-reference")+1:]) + err = generateConfigReference(ctx.terminal, args[slices.Index(args, "--config-reference")+1:]) } else if slices.Contains(args, "--json-schema") { - err = generateJsonSchema(output, args[slices.Index(args, "--json-schema")+1:]) + err = generateJsonSchema(ctx.terminal, args[slices.Index(args, "--json-schema")+1:]) } else if slices.Contains(args, "--random-key") { ctx.flags.resticArgs = args[slices.Index(args, "--random-key"):] - err = randomKey(output, ctx) + err = randomKey(ctx) } else if slices.Contains(args, "--zsh-completion") { - _, err = fmt.Fprintln(output, zshCompletionScript) + _, err = ctx.terminal.Println(zshCompletionScript) } else if slices.Contains(args, "--fish-completion") { - _, err = fmt.Fprintln(output, fishCompletionScript) + _, err = ctx.terminal.Println(fishCompletionScript) } else { err = fmt.Errorf("nothing to generate for: %s", strings.Join(args, ", ")) } diff --git a/commands_schedule.go b/commands_schedule.go index 552ea3da..fbe2d0ec 100644 --- a/commands_schedule.go +++ b/commands_schedule.go @@ -3,7 +3,6 @@ package main import ( "errors" "fmt" - "io" "maps" "slices" "strings" @@ -20,7 +19,7 @@ const ( ) // createSchedule command -func createSchedule(_ io.Writer, ctx commandContext) error { +func createSchedule(ctx commandContext) error { c := ctx.config request := ctx.request args := ctx.request.arguments @@ -76,7 +75,7 @@ func createSchedule(_ io.Writer, ctx commandContext) error { return nil } -func removeSchedule(_ io.Writer, ctx commandContext) error { +func removeSchedule(ctx commandContext) error { var err error c := ctx.config request := ctx.request @@ -116,7 +115,7 @@ func removeSchedule(_ io.Writer, ctx commandContext) error { return nil } -func statusSchedule(w io.Writer, ctx commandContext) error { +func statusSchedule(ctx commandContext) error { c := ctx.config request := ctx.request args := ctx.request.arguments @@ -345,7 +344,7 @@ func prepareScheduledProfile(ctx *Context) { } } -func runSchedule(_ io.Writer, cmdCtx commandContext) error { +func runSchedule(cmdCtx commandContext) error { err := startProfileOrGroup(&cmdCtx.Context, runProfile) if err != nil { return err diff --git a/commands_schedule_test.go b/commands_schedule_test.go index 0550a798..4d591407 100644 --- a/commands_schedule_test.go +++ b/commands_schedule_test.go @@ -108,6 +108,10 @@ func TestStatusScheduleIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -115,13 +119,11 @@ func TestStatusScheduleIntegrationUsingCrontab(t *testing.T) { request: Request{ profile: tc.profileName, }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) if tc.err != nil { assert.ErrorIs(t, err, tc.err) return @@ -157,6 +159,10 @@ func TestRemoveScheduleIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -164,14 +170,12 @@ func TestRemoveScheduleIntegrationUsingCrontab(t *testing.T) { request: Request{ profile: "profile", }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) // this should remove the 2 lines that resemble a resticprofile schedule - err = removeSchedule(output, ctx) + err = removeSchedule(ctx) require.NoError(t, err) result, err := os.ReadFile(crontab) @@ -179,7 +183,7 @@ func TestRemoveScheduleIntegrationUsingCrontab(t *testing.T) { fmt.Println(string(result)) // but one line in error should be left in the crontab - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) require.ErrorIs(t, err, crond.ErrEntryNoMatch) } @@ -202,6 +206,10 @@ func TestCreateScheduleIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -209,14 +217,12 @@ func TestCreateScheduleIntegrationUsingCrontab(t *testing.T) { request: Request{ profile: "profile-schedule-struct", }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) // create one schedule - err = createSchedule(output, ctx) + err = createSchedule(ctx) require.NoError(t, err) result, err := os.ReadFile(crontab) @@ -225,7 +231,7 @@ func TestCreateScheduleIntegrationUsingCrontab(t *testing.T) { fmt.Println(string(result)) output.Reset() - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) require.NoError(t, err) // verify the new schedule was added @@ -258,6 +264,10 @@ func TestCreateScheduleOverwriteExistingIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -265,14 +275,12 @@ func TestCreateScheduleOverwriteExistingIntegrationUsingCrontab(t *testing.T) { request: Request{ arguments: []string{"--all", "--reload", "--no-start"}, }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) // create (or update) two schedules - err = createSchedule(output, ctx) + err = createSchedule(ctx) require.NoError(t, err) result, err := os.ReadFile(crontab) @@ -282,7 +290,7 @@ func TestCreateScheduleOverwriteExistingIntegrationUsingCrontab(t *testing.T) { fmt.Println(string(result)) output.Reset() - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) require.NoError(t, err) // verify the schedule was replaced diff --git a/commands_test.go b/commands_test.go index 38a6ba57..3b1855f0 100644 --- a/commands_test.go +++ b/commands_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "os" "sort" "strings" "testing" @@ -12,6 +11,7 @@ import ( "github.com/creativeprojects/resticprofile/config" "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/schedule" + "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util/collect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,29 +19,31 @@ import ( func TestPanicCommand(t *testing.T) { assert.Panics(t, func() { - _ = panicCommand(nil, commandContext{}) + _ = panicCommand(commandContext{}) }) } func TestRandomKeyOfInvalidSize(t *testing.T) { - assert.Error(t, randomKey(os.Stdout, commandContext{ + assert.Error(t, randomKey(commandContext{ Context: Context{ - flags: commandLineFlags{resticArgs: []string{"restic", "size"}}, + flags: commandLineFlags{resticArgs: []string{"restic", "size"}}, + terminal: term.NewTerminal(), }, })) } func TestRandomKeyOfZeroSize(t *testing.T) { - assert.Error(t, randomKey(os.Stdout, commandContext{ + assert.Error(t, randomKey(commandContext{ Context: Context{ - flags: commandLineFlags{resticArgs: []string{"restic", "0"}}, + flags: commandLineFlags{resticArgs: []string{"restic", "0"}}, + terminal: term.NewTerminal(), }, })) } func TestRandomKey(t *testing.T) { // doesn't look like much, but it's testing the random generator is not throwing an error - assert.NoError(t, randomKey(os.Stdout, commandContext{})) + assert.NoError(t, randomKey(commandContext{Context: Context{terminal: term.NewTerminal()}})) } func TestRemovableSchedules(t *testing.T) { @@ -199,9 +201,12 @@ func TestCompleteCall(t *testing.T) { for _, test := range testTable { t.Run(strings.Join(test.args, " "), func(t *testing.T) { buffer := &strings.Builder{} - assert.Nil(t, completeCommand(buffer, commandContext{ + assert.Nil(t, completeCommand(commandContext{ ownCommands: ownCommands, - Context: Context{request: Request{arguments: test.args}}, + Context: Context{ + request: Request{arguments: test.args}, + terminal: term.NewTerminal(term.WithStdout(buffer)), + }, })) assert.Equal(t, test.expected, buffer.String()) }) @@ -213,33 +218,38 @@ func TestGenerateCommand(t *testing.T) { contextWithArguments := func(args []string) commandContext { t.Helper() - return commandContext{Context: Context{request: Request{arguments: args}}} + return commandContext{ + Context: Context{ + request: Request{arguments: args}, + terminal: term.NewTerminal(term.WithStdout(buffer)), + }, + } } t.Run("--bash-completion", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--bash-completion"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--bash-completion"}))) assert.Equal(t, strings.TrimSpace(bashCompletionScript), strings.TrimSpace(buffer.String())) assert.Contains(t, bashCompletionScript, "#!/usr/bin/env bash") }) t.Run("--zsh-completion", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--zsh-completion"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--zsh-completion"}))) assert.Equal(t, strings.TrimSpace(zshCompletionScript), strings.TrimSpace(buffer.String())) assert.Contains(t, zshCompletionScript, "#!/usr/bin/env zsh") }) t.Run("--fish-completion", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--fish-completion"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--fish-completion"}))) assert.Equal(t, strings.TrimSpace(fishCompletionScript), strings.TrimSpace(buffer.String())) assert.Contains(t, fishCompletionScript, "#!/usr/bin/env fish") }) t.Run("--config-reference", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--config-reference", "--to", t.TempDir()}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--config-reference", "--to", t.TempDir()}))) ref := buffer.String() assert.Contains(t, ref, "generating reference.gomd") assert.Contains(t, ref, "generating profile section") @@ -248,7 +258,7 @@ func TestGenerateCommand(t *testing.T) { t.Run("--json-schema global", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "global"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "global"}))) ref := buffer.String() assert.Contains(t, ref, `"$schema"`) assert.Contains(t, ref, "/jsonschema/config-1.json") @@ -262,24 +272,24 @@ func TestGenerateCommand(t *testing.T) { t.Run("--json-schema no-option", func(t *testing.T) { buffer.Reset() - assert.Error(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema"}))) + assert.Error(t, generateCommand(contextWithArguments([]string{"--json-schema"}))) }) t.Run("--json-schema invalid-option", func(t *testing.T) { buffer.Reset() - assert.Error(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "_invalid_"}))) + assert.Error(t, generateCommand(contextWithArguments([]string{"--json-schema", "_invalid_"}))) }) t.Run("--json-schema v1", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "v1"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "v1"}))) ref := buffer.String() assert.Contains(t, ref, "/jsonschema/config-1.json") }) t.Run("--json-schema v2", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "v2"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "v2"}))) ref := buffer.String() assert.Contains(t, ref, "\"profiles\":") assert.Contains(t, ref, "/jsonschema/config-2.json") @@ -287,14 +297,14 @@ func TestGenerateCommand(t *testing.T) { t.Run("--json-schema --version 0.13 v1", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "--version", "0.13", "v1"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "--version", "0.13", "v1"}))) ref := buffer.String() assert.Contains(t, ref, "/jsonschema/config-1-restic-0-13.json") }) t.Run("--random-key", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--random-key", "512"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--random-key", "512"}))) assert.Equal(t, 684, len(strings.TrimSpace(buffer.String()))) }) @@ -303,7 +313,7 @@ func TestGenerateCommand(t *testing.T) { opts := []string{"", "invalid", "--unknown"} for _, option := range opts { buffer.Reset() - err := generateCommand(buffer, contextWithArguments([]string{option})) + err := generateCommand(contextWithArguments([]string{option})) assert.EqualError(t, err, fmt.Sprintf("nothing to generate for: %s", option)) assert.Equal(t, 0, buffer.Len()) } @@ -311,7 +321,6 @@ func TestGenerateCommand(t *testing.T) { } func TestShowSchedules(t *testing.T) { - buffer := &bytes.Buffer{} create := func(command string, at ...string) *config.Schedule { origin := config.ScheduleOrigin("default", command) return config.NewDefaultSchedule(nil, origin, at...) @@ -339,7 +348,9 @@ schedule check@default: lock-mode: default capture-environment: RESTIC_* `) - showSchedules(buffer, schedules) + buffer := &bytes.Buffer{} + terminal := term.NewTerminal(term.WithStdout(buffer)) + showSchedules(terminal.Stdout(), schedules) assert.Equal(t, expected, strings.TrimSpace(buffer.String())) } @@ -348,7 +359,7 @@ func TestCreateScheduleWhenNoneAvailable(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = createSchedule(nil, commandContext{ + err = createSchedule(commandContext{ Context: Context{ config: cfg, flags: commandLineFlags{ @@ -366,7 +377,7 @@ func TestCreateScheduleAll(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = createSchedule(nil, commandContext{ + err = createSchedule(commandContext{ Context: Context{ config: cfg, request: Request{ @@ -426,7 +437,7 @@ func TestRunScheduleNoScheduleName(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = runSchedule(nil, commandContext{ + err = runSchedule(commandContext{ Context: Context{ config: cfg, flags: commandLineFlags{ @@ -443,7 +454,7 @@ func TestRunScheduleWrongScheduleName(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = runSchedule(nil, commandContext{ + err = runSchedule(commandContext{ Context: Context{ request: Request{arguments: []string{"wrong"}}, config: cfg, @@ -461,7 +472,7 @@ func TestRunScheduleProfileUnknown(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = runSchedule(nil, commandContext{ + err = runSchedule(commandContext{ Context: Context{ request: Request{arguments: []string{"backup@profile"}}, config: cfg, @@ -472,6 +483,6 @@ func TestRunScheduleProfileUnknown(t *testing.T) { func TestBatteryCommand(t *testing.T) { buffer := &bytes.Buffer{} - err := batteryCommand(buffer, commandContext{}) + err := batteryCommand(commandContext{Context: Context{terminal: term.NewTerminal(term.WithStdout(buffer))}}) require.NoError(t, err) } diff --git a/config/flag.go b/config/flag.go index 35ab236c..b4c7bcd6 100644 --- a/config/flag.go +++ b/config/flag.go @@ -232,7 +232,10 @@ func stringify(value reflect.Value, onlySimplyValues bool) ([]string, bool) { sort.Strings(flatMap) return flatMap, len(flatMap) > 0 - case reflect.Interface: + case reflect.Interface, reflect.Pointer: + if value.IsNil() { + return emptyStringArray, false + } return stringify(value.Elem(), onlySimplyValues) default: diff --git a/config/flag_test.go b/config/flag_test.go index ac126f47..8b60c885 100644 --- a/config/flag_test.go +++ b/config/flag_test.go @@ -10,11 +10,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestPointerValueShouldReturnErrorMessage(t *testing.T) { +func TestPointerValueShouldReturnValue(t *testing.T) { concrete := "test" value := &concrete argValue, _ := stringifyValueOf(value) - assert.Equal(t, []string{"ERROR: unexpected type ptr"}, argValue) + assert.Equal(t, []string{"test"}, argValue) } func TestNilValueFlag(t *testing.T) { diff --git a/config/info.go b/config/info.go index 7c85d44c..880a1a02 100644 --- a/config/info.go +++ b/config/info.go @@ -672,11 +672,10 @@ func NewProfileInfoForRestic(resticVersion string, withDefaultOptions bool) Prof // Building initial set including generic sections (from data model) profileSet := propertySetFromType(infoTypes.profile) { - genericSection := propertySetFromType(infoTypes.genericSection) for _, name := range infoTypes.genericSectionNames { pi := new(propertyInfo) pi.nested = &namedPropertySet{ - propertySet: genericSection, + propertySet: propertySetFromType(infoTypes.genericSection), name: name, } profileSet.properties[name] = pi diff --git a/context.go b/context.go index 7a968951..01f1efd2 100644 --- a/context.go +++ b/context.go @@ -6,6 +6,7 @@ import ( "github.com/creativeprojects/resticprofile/config" "github.com/creativeprojects/resticprofile/constants" + "github.com/creativeprojects/resticprofile/term" ) type Request struct { @@ -17,8 +18,8 @@ type Request struct { } // Context for running a profile command. -// Not everything is always available, -// but any information should be added to the context as soon as known. +// All fields are not always populated at the same time, +// as the context is built up step by step when running a profile command. type Context struct { request Request flags commandLineFlags @@ -35,6 +36,7 @@ type Context struct { noLock bool // skip profile lock file lockWait time.Duration // wait up to duration to acquire a lock legacyArgs bool // I'm not even sure it's been used by anyone? + terminal *term.Terminal } func CreateContext(flags commandLineFlags, global *config.Global, cfg *config.Config, ownCommands *OwnCommands) (*Context, error) { @@ -135,6 +137,12 @@ func (c *Context) WithProfile(profileName string) *Context { return newContext } +func (c *Context) WithTerminal(terminal *term.Terminal) *Context { + newContext := c.clone() + newContext.terminal = terminal + return newContext +} + func (c *Context) clone() *Context { clone := *c return &clone diff --git a/context_test.go b/context_test.go index cecd55a7..9e6e462f 100644 --- a/context_test.go +++ b/context_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/creativeprojects/resticprofile/config" + "github.com/creativeprojects/resticprofile/term" "github.com/stretchr/testify/assert" ) @@ -56,6 +57,13 @@ func TestContextWithGroup(t *testing.T) { assert.NotEmpty(t, ctx.request.command) } +func TestContextWithTerminal(t *testing.T) { + terminal := term.NewTerminal() + ctx := &Context{} + ctx = ctx.WithTerminal(terminal) + assert.Same(t, terminal, ctx.terminal) +} + func TestContextWithProfile(t *testing.T) { ctx := &Context{ request: Request{ diff --git a/flags.go b/flags.go index 87dfd51c..1b51c7fd 100644 --- a/flags.go +++ b/flags.go @@ -139,7 +139,7 @@ func loadFlags(args []string) (*pflag.FlagSet, commandLineFlags, error) { flagset.SetInterspersed(false) // Store usage help for help command - width, _ := term.OsStdoutTerminalSize() + width, _ := term.Size() flags.usagesHelp = flagset.FlagUsagesWrapped(width) if err := flagset.Parse(args); err != nil { diff --git a/integration_test.go b/integration_test.go index 5e8aea48..bbf45e59 100644 --- a/integration_test.go +++ b/integration_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "os" "path/filepath" "strings" "testing" @@ -121,22 +120,23 @@ func TestFromConfigFileToCommandLine(t *testing.T) { require.NoError(t, err) require.NotNil(t, profile) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) + ctx := &Context{ request: Request{ profile: fixture.profileName, arguments: fixture.cmdlineArgs, }, - binary: echoBinary, - profile: profile, - command: fixture.commandName, + binary: echoBinary, + profile: profile, + command: fixture.commandName, + terminal: terminal, } wrapper := newResticWrapper(ctx) - buffer := &bytes.Buffer{} - // setting the output via the package global setter could lead to some issues - // when some tests are running in parallel. I should fix that at some point :-/ - term.SetOutput(buffer) + err = wrapper.runCommand(fixture.commandName) - term.SetOutput(os.Stdout) + stdout := buffer.String() require.NoError(t, err) @@ -148,7 +148,7 @@ func TestFromConfigFileToCommandLine(t *testing.T) { t.SkipNow() } } - assert.Equal(t, expected, strings.TrimSpace(buffer.String())) + assert.Equal(t, expected, strings.TrimSpace(stdout)) }) if platform.IsWindows() { @@ -161,6 +161,9 @@ func TestFromConfigFileToCommandLine(t *testing.T) { require.NoError(t, err) require.NotNil(t, profile) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) + ctx := &Context{ request: Request{ profile: fixture.profileName, @@ -170,18 +173,16 @@ func TestFromConfigFileToCommandLine(t *testing.T) { profile: profile, command: fixture.commandName, legacyArgs: true, + terminal: terminal, } wrapper := newResticWrapper(ctx) - buffer := &bytes.Buffer{} - // setting the output via the package global setter could lead to some issues - // when some tests are running in parallel. I should fix that at some point :-/ - term.SetOutput(buffer) + err = wrapper.runCommand(fixture.commandName) - term.SetOutput(os.Stdout) + content := buffer.String() require.NoError(t, err) - assert.Equal(t, fixture.legacy, strings.TrimSpace(buffer.String())) + assert.Equal(t, fixture.legacy, strings.TrimSpace(content)) }) } }) diff --git a/logger.go b/logger.go index bcb0619c..b612b72e 100644 --- a/logger.go +++ b/logger.go @@ -8,7 +8,6 @@ import ( "path/filepath" "slices" "strings" - "time" "github.com/creativeprojects/clog" "github.com/creativeprojects/resticprofile/constants" @@ -18,6 +17,7 @@ import ( "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util" "github.com/creativeprojects/resticprofile/util/collect" + "github.com/creativeprojects/resticprofile/util/write" "github.com/fatih/color" ) @@ -50,7 +50,7 @@ func setupRemoteLogger(flags commandLineFlags, client *remote.Client) { clog.SetDefaultLogger(logger) } -func setupTargetLogger(flags commandLineFlags, logTarget, commandOutput string) (io.Closer, error) { +func setupTargetLogger(flags commandLineFlags, terminal *term.Terminal, logTarget, commandOutput string) (io.Closer, []term.TerminalOption, error) { var ( handler LogCloser file io.Writer @@ -64,29 +64,36 @@ func setupTargetLogger(flags commandLineFlags, logTarget, commandOutput string) handler, file, err = getFileHandler(logTarget) } if err != nil { - return nil, err + return nil, nil, err } // use the console handler as a backup logger := newFilteredLogger(flags, clog.NewSafeHandler(handler, clog.NewConsoleHandler("", log.LstdFlags))) // default logger added with level filtering clog.SetDefaultLogger(logger) + var terminalOptions []term.TerminalOption // also redirect all terminal output if file != nil { - if all, toLog := parseCommandOutput(commandOutput); all { - term.SetOutput(io.MultiWriter(file, term.GetOutput())) - term.SetErrorOutput(io.MultiWriter(file, term.GetErrorOutput())) + if all, toLog := parseCommandOutput(terminal, commandOutput); all { + clog.Debugf("sending a copy of the console logs to %q", logTarget) + terminalOptions = []term.TerminalOption{ + term.WithStdoutCopy(file), + term.WithStderrCopy(file), + } } else if toLog { - term.SetAllOutput(file) + terminalOptions = []term.TerminalOption{ + term.WithStdout(file), + term.WithStderr(file), + } } } // and return the handler (so we can close it at the end) - return handler, nil + return handler, terminalOptions, nil } -func parseCommandOutput(commandOutput string) (all, log bool) { +func parseCommandOutput(terminal *term.Terminal, commandOutput string) (all, log bool) { if strings.TrimSpace(commandOutput) == "auto" { - if term.OsStdoutIsTerminal() { + if terminal.StdoutIsTerminal() { commandOutput = "log,console" } else { commandOutput = "log" @@ -114,9 +121,8 @@ func getFileHandler(logfile string) (*clog.StandardLogHandler, io.Writer, error) } // create a platform aware log file appender - keepOpen, appender := true, appendFunc(nil) + var appender write.WriterAppendFunc if platform.IsWindows() { - keepOpen = false appender = func(dst []byte, c byte) []byte { switch c { case '\n': @@ -128,10 +134,11 @@ func getFileHandler(logfile string) (*clog.StandardLogHandler, io.Writer, error) } } - writer, err := newDeferredFileWriter(logfile, keepOpen, appender) + file, err := write.NewFile(logfile, write.WithFilePerm(0644)) if err != nil { return nil, nil, err } + writer := write.NewAsync(write.NewAppend(file, appender)) return clog.NewStandardLogHandler(writer, "", log.LstdFlags), writer, nil } @@ -173,140 +180,3 @@ func changeLevelFilter(level clog.LogLevel) { filter.SetLevel(level) } } - -// deferredFileWriter accumulates Write requests and writes them at a fixed rate (every 250 ms) -type deferredFileWriter struct { - done, flush chan chan error - data chan []byte -} - -func (d *deferredFileWriter) Close() error { - req := make(chan error) - d.done <- req - return <-req -} - -func (d *deferredFileWriter) Flush() error { - req := make(chan error) - d.flush <- req - return <-req -} - -func (d *deferredFileWriter) Write(data []byte) (n int, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic: %v", r) - } - }() - c := make([]byte, len(data)) - n = copy(c, data) - d.data <- c - return -} - -type appendFunc func(dst []byte, c byte) []byte - -func newDeferredFileWriter(filename string, keepOpen bool, appender appendFunc) (io.WriteCloser, error) { - d := &deferredFileWriter{ - flush: make(chan chan error), - done: make(chan chan error), - data: make(chan []byte, 64), - } - - var ( - buffer []byte - lastError error - file *os.File - ) - - closeFile := func() { - if file != nil { - lastError = file.Close() - file = nil - } - } - - flush := func(alsoEmpty bool) { - if len(buffer) == 0 && !alsoEmpty { - return - } - if file == nil { - file, lastError = os.OpenFile(filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) //nolint:gosec - } - if file != nil { - var written int - written, lastError = file.Write(buffer) - if written == len(buffer) { - buffer = buffer[:0] - } else { - buffer = buffer[written:] - } - } - if keepOpen { - _ = file.Sync() - } else { - closeFile() - } - } - - // test if we can create the file - buffer = make([]byte, 0, 4096) - flush(true) - - // data appending - addToBuffer := func(data []byte) { - buffer = append(buffer, data...) // fast path - } - if appender != nil { - addToBuffer = func(data []byte) { - for _, c := range data { - buffer = appender(buffer, c) - } - } - } - - addPendingData := func(size int) { - for ; size > 0; size-- { - select { - case data, ok := <-d.data: - if ok { - addToBuffer(data) - } else { - return // closed - } - default: - return // no-more-data - } - } - } - - // data transport - go func() { - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case data := <-d.data: - addToBuffer(data) - case <-ticker.C: - flush(false) - case req := <-d.flush: - addPendingData(1024) - flush(false) - req <- lastError - case req := <-d.done: - close(d.done) - close(d.flush) - close(d.data) - addPendingData(1024) - flush(false) - closeFile() - req <- lastError - return - } - } - }() - - return d, lastError -} diff --git a/logger_test.go b/logger_test.go index 43cb0afe..ae335e0a 100644 --- a/logger_test.go +++ b/logger_test.go @@ -53,8 +53,8 @@ func TestFileHandler(t *testing.T) { require.NoError(t, err) defer handler.Close() - require.Implements(t, (*term.Flusher)(nil), writer) - flusher := writer.(term.Flusher) + require.Implements(t, (*util.Flusher)(nil), writer) + flusher := writer.(util.Flusher) log := func(line string) { assert.NoError(t, handler.LogEntry(clog.LogEntry{Level: clog.LevelInfo, Format: line})) @@ -99,12 +99,13 @@ func TestFileHandler(t *testing.T) { } func TestParseCommandOutput(t *testing.T) { + terminal := term.NewTerminal() tests := []struct { co string all, log bool }{ {co: "", all: false, log: false}, - {co: "auto", all: term.OsStdoutIsTerminal(), log: true}, + {co: "auto", all: terminal.StdoutIsTerminal(), log: true}, {co: "log", all: false, log: true}, {co: "console", all: false, log: false}, {co: "all", all: true, log: false}, @@ -114,7 +115,7 @@ func TestParseCommandOutput(t *testing.T) { {co: "log,a", all: false, log: true}, {co: "console,a", all: false, log: false}, - {co: " auto ", all: term.OsStdoutIsTerminal(), log: true}, + {co: " auto ", all: terminal.StdoutIsTerminal(), log: true}, {co: " all ", all: true, log: false}, {co: " log ", all: false, log: true}, {co: " console ", all: false, log: false}, @@ -124,7 +125,7 @@ func TestParseCommandOutput(t *testing.T) { for i, test := range tests { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - a, l := parseCommandOutput(test.co) + a, l := parseCommandOutput(terminal, test.co) assert.Equal(t, test.all, a, "all") assert.Equal(t, test.log, l, "log") }) diff --git a/main.go b/main.go index c06edfc1..e8f72c7a 100644 --- a/main.go +++ b/main.go @@ -48,29 +48,40 @@ func main() { // run shutdown hooks just before returning an exit code defer shutdown.RunHooks() + // prepare a default terminal before loading the terminal configuration + terminal := term.Set(term.NewTerminal()) + terminalOptions := make([]term.TerminalOption, 0, 5) + args := os.Args[1:] _, flags, flagErr := loadFlags(args) if flagErr != nil && !errors.Is(flagErr, pflag.ErrHelp) { - term.Println(flagErr) - _ = displayHelpCommand(os.Stdout, commandContext{ + terminal.Println(flagErr) + _ = displayHelpCommand(commandContext{ ownCommands: ownCommands, Context: Context{ flags: flags, request: Request{ arguments: args, }, + terminal: terminal, }, }) exitCode = constants.ExitErrorInvalidFlags return } + // Now that we have loaded the flags, configure terminal color + if flags.noAnsi || flags.theme == "none" { + terminalOptions = append(terminalOptions, term.WithColors(false)) // disable colors + terminal = term.Set(term.NewTerminal(terminalOptions...)) + } + if flags.wait { // keep the console running at the end of the program // so we can see what's going on defer func() { - term.Println("\n\nPress the Enter Key to continue...") - _, _ = fmt.Scanln() + terminal.Println("\n\nPress Enter to continue...") + _, _ = terminal.Scanln() }() } @@ -83,13 +94,14 @@ func main() { // help if flags.help || errors.Is(flagErr, pflag.ErrHelp) { - _ = displayHelpCommand(os.Stdout, commandContext{ + _ = displayHelpCommand(commandContext{ ownCommands: ownCommands, Context: Context{ flags: flags, request: Request{ arguments: args, }, + terminal: terminal, }, }) return @@ -106,42 +118,52 @@ func main() { setupRemoteLogger(flags, client) // also redirect the terminal through the client - term.SetAllOutput(term.NewRemoteTerm(client)) - } else { - logTarget, commandOutput := "", "" - if ctx != nil { - logTarget = ctx.logTarget - commandOutput = ctx.commandOutput - if ctx.request.command == constants.CommandCat || - ctx.request.command == constants.CommandDump { - clog.Debugf("redirecting console to stderr for command %q", ctx.request.command) - flags.stderr = true - } - term.PrintToError = flags.stderr + remoteTerm := term.NewRemoteTerm(client) + terminalOptions = append(terminalOptions, term.WithStdout(remoteTerm), term.WithStderr(remoteTerm)) + + return + } + logTarget, commandOutput := "", "" + if ctx != nil { + logTarget = ctx.logTarget + commandOutput = ctx.commandOutput + if ctx.request.command == constants.CommandCat || + ctx.request.command == constants.CommandDump { + clog.Debugf("redirecting console to stderr for command %q", ctx.request.command) + flags.stderr = true } - if logTarget != "" && logTarget != "-" { - if closer, err := setupTargetLogger(flags, logTarget, commandOutput); err == nil { - logCloser = func() { _ = closer.Close() } - } else { - // fallback to a console logger - setupConsoleLogger(flags) - clog.Errorf("cannot open log target: %s", err) - } - } else { - // use the console logger - setupConsoleLogger(flags) + if flags.stderr { + terminalOptions = append(terminalOptions, term.WithStdout(os.Stderr)) } } + if logTarget != "" && logTarget != "-" { + if closer, options, err := setupTargetLogger(flags, terminal, logTarget, commandOutput); err == nil { + logCloser = func() { _ = closer.Close() } + terminalOptions = append(terminalOptions, options...) + return + } + // fallback to a console logger + setupConsoleLogger(flags) + clog.Errorf("cannot open log target: %s", err) + + return + } + // use the console logger + setupConsoleLogger(flags) + return } + // refresh terminal with new config + terminal = term.Set(term.NewTerminal(terminalOptions...)) + // keep this one last if possible (so it will be first at the end) defer showPanicData() banner() if flags.remote != "" { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) closeFS, remoteParameters, err := setupRemoteConfiguration(ctx, flags.remote) cancel() if err != nil { @@ -171,14 +193,22 @@ func main() { command: flags.resticArgs[0], arguments: flags.resticArgs[1:], }, + logTarget: flags.log, + commandOutput: flags.commandOutput, } + // try to load the config and setup logging for own command cfg, global, err := loadConfig(flags, true) if err == nil { ctx = ctx.WithConfig(cfg, global) } closeLogger := setupLogging(ctx) - defer closeLogger() + shutdown.AddHook(closeLogger) + + // refresh terminal with new logging config + terminal = term.Set(term.NewTerminal(terminalOptions...)) + ctx = ctx.WithTerminal(terminal) + err = ownCommands.Run(ctx) if err != nil { clog.Error(err) @@ -196,13 +226,17 @@ func main() { // Load the now mandatory configuration and setup logging (before returning an error) ctx, err := loadContext(flags) closeLogger := setupLogging(ctx) - defer closeLogger() + shutdown.AddHook(closeLogger) if err != nil { clog.Error(err) exitCode = constants.ExitGeneralError return } + // refresh terminal with new config + terminal = term.Set(term.NewTerminal(terminalOptions...)) + ctx = ctx.WithTerminal(terminal) + // check if we're running on battery if shouldStopOnBattery(ctx.stopOnBattery) { exitCode = constants.ExitRunningOnBattery @@ -220,14 +254,14 @@ func main() { } } // and stop at the end - defer func() { + shutdown.AddHook(func() { if caffeinate != nil { err = caffeinate.Stop() if err != nil { clog.Error(err) } } - }() + }) // Check memory pressure if ctx.global.MinMemory > 0 { @@ -285,8 +319,9 @@ func main() { if err != nil { clog.Error(err) if errors.Is(err, ErrProfileNotFound) { - displayProfiles(os.Stdout, ctx.config, flags) - displayGroups(os.Stdout, ctx.config, flags) + commandContext := commandContext{Context: *ctx} + displayProfiles(commandContext) + displayGroups(commandContext) } exitCode = constants.ExitGeneralError return @@ -412,14 +447,15 @@ func free() uint64 { func showPanicData() { if r := recover(); r != nil { + terminal := term.Get() message := ` -=============================================================== -uh-oh! resticprofile crashed miserably :-( -Can you please open an issue on github including these details: -=============================================================== +================================================================= + uh-oh! resticprofile crashed miserably :-( + Can you please open an issue on github including these details: +================================================================= ` - fmt.Fprint(os.Stderr, message) - w := tabwriter.NewWriter(os.Stderr, 0, 0, 3, ' ', 0) + fmt.Fprint(terminal.Stderr(), message) + w := tabwriter.NewWriter(terminal.Stderr(), 0, 0, 3, ' ', 0) _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "os", runtime.GOOS) _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "arch", runtime.GOARCH) _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "version", version) @@ -429,6 +465,6 @@ Can you please open an issue on github including these details: _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "error", r) _, _ = fmt.Fprintf(w, "\t%s:\n%s\n", "stack", getStack(3)) // skip calls to getStack - showPanicData - panic w.Flush() - fmt.Fprint(os.Stderr, "===============================================================\n") + fmt.Fprintln(terminal.Stderr(), "=================================================================") } } diff --git a/own_commands.go b/own_commands.go index 6db21d04..a68d5dfc 100644 --- a/own_commands.go +++ b/own_commands.go @@ -2,8 +2,6 @@ package main import ( "fmt" - "io" - "os" "strings" "github.com/creativeprojects/clog" @@ -20,14 +18,14 @@ type ownCommand struct { name string description string longDescription string - pre func(*Context) error // pre-command action (for checking the context) - action func(io.Writer, commandContext) error // run command action - needConfiguration bool // true if the action needs a configuration file loaded - hide bool // don't display the command in help and completion - hideInCompletion bool // don't display the command in completion - noProfile bool // true if the command doesn't need a profile name - experimental bool // display a warning when using this command - flags map[string]string // own command flags should be simple enough to be handled manually for now + pre func(*Context) error // pre-command action (for checking the context) + action func(commandContext) error // run command action + needConfiguration bool // true if the action needs a configuration file loaded + hide bool // don't display the command in help and completion + hideInCompletion bool // don't display the command in completion + noProfile bool // true if the command doesn't need a profile name + experimental bool // display a warning when using this command + flags map[string]string // own command flags should be simple enough to be handled manually for now } // OwnCommands is a list of resticprofile commands @@ -68,7 +66,7 @@ func (o *OwnCommands) Run(ctx *Context) error { if command.experimental { clog.Warningf("%s: this command is experimental and its behaviour may change in the future", ctx.request.command) } - return command.action(os.Stdout, commandContext{ + return command.action(commandContext{ ownCommands: o, Context: *ctx, }) diff --git a/own_commands_test.go b/own_commands_test.go index 09334623..386c8e9d 100644 --- a/own_commands_test.go +++ b/own_commands_test.go @@ -2,10 +2,10 @@ package main import ( "errors" - "io" "strings" "testing" + "github.com/creativeprojects/resticprofile/term" "github.com/stretchr/testify/assert" ) @@ -40,15 +40,15 @@ func fakeCommands() *OwnCommands { return ownCommands } -func firstCommand(_ io.Writer, _ commandContext) error { +func firstCommand(_ commandContext) error { return errors.New("first") } -func secondCommand(_ io.Writer, _ commandContext) error { +func secondCommand(_ commandContext) error { return errors.New("second") } -func thirdCommand(_ io.Writer, _ commandContext) error { +func thirdCommand(_ commandContext) error { return errors.New("third") } @@ -58,13 +58,15 @@ func pre(_ *Context) error { func TestDisplayOwnCommands(t *testing.T) { buffer := &strings.Builder{} - displayOwnCommands(buffer, commandContext{ownCommands: fakeCommands()}) + terminal := term.NewTerminal(term.WithStdout(buffer)) + displayOwnCommands(commandContext{ownCommands: fakeCommands(), Context: Context{terminal: terminal}}) assert.Equal(t, " first first first\n second second second\n", buffer.String()) } func TestDisplayOwnCommand(t *testing.T) { buffer := &strings.Builder{} - displayOwnCommandHelp(buffer, "second", commandContext{ownCommands: fakeCommands()}) + terminal := term.NewTerminal(term.WithStdout(buffer)) + displayOwnCommandHelp(commandContext{ownCommands: fakeCommands(), Context: Context{terminal: terminal}}, "second") assert.Equal(t, `Purpose: second second Usage: diff --git a/schedule/handler_crond.go b/schedule/handler_crond.go index 4e819e8b..48b9e31f 100644 --- a/schedule/handler_crond.go +++ b/schedule/handler_crond.go @@ -12,6 +12,7 @@ import ( "github.com/creativeprojects/resticprofile/crond" "github.com/creativeprojects/resticprofile/platform" "github.com/creativeprojects/resticprofile/shell" + "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/user" "github.com/spf13/afero" ) @@ -63,7 +64,7 @@ func (h *HandlerCrond) DisplaySchedules(profile, command string, schedules []str if err != nil { return err } - displayParsedSchedules(profile, command, events) + displayParsedSchedules(term.Get(), profile, command, events) return nil } diff --git a/schedule/handler_darwin.go b/schedule/handler_darwin.go index 5912d0a2..c832c3b6 100644 --- a/schedule/handler_darwin.go +++ b/schedule/handler_darwin.go @@ -69,7 +69,7 @@ func (h *HandlerLaunchd) DisplaySchedules(profile, command string, schedules []s if err != nil { return err } - displayParsedSchedules(profile, command, events) + displayParsedSchedules(term.Get(), profile, command, events) return nil } @@ -151,7 +151,7 @@ func (h *HandlerLaunchd) DisplayJobStatus(job *Config) error { // output was not parsed, it could mean output format has changed clog.Warning("output of 'launchctl print' was either empty or using an incompatible format") } - writer := tabwriter.NewWriter(term.GetOutput(), 1, 1, 1, ' ', tabwriter.AlignRight) + writer := tabwriter.NewWriter(term.Get().Stdout(), 1, 1, 1, ' ', tabwriter.AlignRight) for _, key := range launchctlPrintKeys { key, value := presentStatus(key, status[key]) if len(value) == 0 { diff --git a/schedule/handler_systemd.go b/schedule/handler_systemd.go index b5bccc83..02e3a4e0 100644 --- a/schedule/handler_systemd.go +++ b/schedule/handler_systemd.go @@ -91,7 +91,7 @@ func (h *HandlerSystemd) ParseSchedules(schedules []string) ([]*calendar.Event, // DisplaySchedules displays the schedules through the systemd-analyze command func (h *HandlerSystemd) DisplaySchedules(profile, command string, schedules []string) error { - return displaySystemdSchedules(profile, command, schedules) + return displaySystemdSchedules(term.Get(), profile, command, schedules) } // DisplayStatus displays the status of all the timers installed on that profile. Example: @@ -117,7 +117,7 @@ func (h *HandlerSystemd) DisplayStatus(profileName string) error { // fail silently return nil //nolint:nilerr } - fmt.Fprintf(term.GetOutput(), "\nTimers summary\n===============\n%s\n", status) + fmt.Fprintf(term.Get().Stdout(), "\nTimers summary\n===============\n%s\n", status) return nil } @@ -424,8 +424,9 @@ func runSystemctlOnUnit(timerName, command string, unitType systemd.UnitType, si return err } if !silent { - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + terminal := term.Get() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() } err = cmd.Run() if command == systemctlStatus && cmd.ProcessState.ExitCode() == codeStatusUnitNotFound { @@ -454,8 +455,9 @@ func runJournalCtlCommand(timerName string, unitType systemd.UnitType) error { } clog.Debugf("starting command \"%s %s\"", binary, strings.Join(args, " ")) cmd := exec.CommandContext(context.TODO(), binary, args...) - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + terminal := term.Get() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() err = cmd.Run() fmt.Println("") return err @@ -470,8 +472,9 @@ func runSystemctlReload(unitType systemd.UnitType) error { if err != nil { return err } - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + terminal := term.Get() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() err = cmd.Run() if err != nil { return err @@ -597,7 +600,7 @@ func systemctlCommand(args ...string) (*exec.Cmd, error) { return exec.CommandContext(context.TODO(), binary, args...), nil } -func displaySystemdSchedules(profile, command string, schedules []string) error { +func displaySystemdSchedules(terminal *term.Terminal, profile, command string, schedules []string) error { binary, err := exec.LookPath(analyzeBinary) if err != nil { return fmt.Errorf("cannot find %q: %w", analyzeBinary, err) @@ -607,17 +610,17 @@ func displaySystemdSchedules(profile, command string, schedules []string) error if schedule == "" { return errors.New("empty schedule") } - displayHeader(profile, command, index+1, len(schedules)) + displayHeader(terminal, profile, command, index+1, len(schedules)) cmd := exec.CommandContext(context.TODO(), binary, "calendar", schedule) - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() err = cmd.Run() if err != nil { return err } } - term.Print(platform.LineSeparator) + terminal.Print(platform.LineSeparator) return nil } diff --git a/schedule/handler_systemd_test.go b/schedule/handler_systemd_test.go index 615ff652..defa20df 100644 --- a/schedule/handler_systemd_test.go +++ b/schedule/handler_systemd_test.go @@ -262,16 +262,15 @@ func TestPermissionToSystemd(t *testing.T) { } func TestDisplaySystemdSchedulesWithEmpty(t *testing.T) { - err := displaySystemdSchedules("profile", "command", []string{""}) + err := displaySystemdSchedules(term.NewTerminal(), "profile", "command", []string{""}) require.Error(t, err) } func TestDisplaySystemdSchedules(t *testing.T) { buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + terminal := term.NewTerminal(term.WithStdout(buffer)) - err := displaySystemdSchedules("profile", "command", []string{"daily"}) + err := displaySystemdSchedules(terminal, "profile", "command", []string{"daily"}) require.NoError(t, err) output := buffer.String() diff --git a/schedule/handler_windows.go b/schedule/handler_windows.go index 988b94ee..54111f70 100644 --- a/schedule/handler_windows.go +++ b/schedule/handler_windows.go @@ -10,6 +10,7 @@ import ( "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/schtasks" "github.com/creativeprojects/resticprofile/shell" + "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/user" ) @@ -38,7 +39,7 @@ func (h *HandlerWindows) DisplaySchedules(profile, command string, schedules []s if err != nil { return err } - displayParsedSchedules(profile, command, events) + displayParsedSchedules(term.Get(), profile, command, events) return nil } diff --git a/schedule/schedules.go b/schedule/schedules.go index 7daf73cb..3f261882 100644 --- a/schedule/schedules.go +++ b/schedule/schedules.go @@ -11,16 +11,16 @@ import ( "github.com/creativeprojects/resticprofile/term" ) -func displayHeader(profile, command string, index, total int) { - term.Print(platform.LineSeparator) +func displayHeader(terminal *term.Terminal, profile, command string, index, total int) { + terminal.Print(platform.LineSeparator) header := fmt.Sprintf("Profile (or Group) %s: %s schedule", profile, command) if total > 1 { header += fmt.Sprintf(" %d/%d", index, total) } - term.Print(header) - term.Print(platform.LineSeparator) - term.Print(strings.Repeat("=", len(header))) - term.Print(platform.LineSeparator) + terminal.Print(header) + terminal.Print(platform.LineSeparator) + terminal.Print(strings.Repeat("=", len(header))) + terminal.Print(platform.LineSeparator) } // parseSchedules creates a *calendar.Event from a string @@ -40,20 +40,20 @@ func parseSchedules(schedules []string) ([]*calendar.Event, error) { return events, nil } -func displayParsedSchedules(profile, command string, events []*calendar.Event) { +func displayParsedSchedules(terminal *term.Terminal, profile, command string, events []*calendar.Event) { now := time.Now().Round(time.Second) for index, event := range events { - displayHeader(profile, command, index+1, len(events)) + displayHeader(terminal, profile, command, index+1, len(events)) next := event.Next(now) - term.Printf(" Original form: %s\n", event.Input()) - term.Printf("Normalized form: %s\n", event.String()) + terminal.Printf(" Original form: %s\n", event.Input()) + terminal.Printf("Normalized form: %s\n", event.String()) if next.IsZero() { - term.Print(" Next elapse: never\n") + terminal.Print(" Next elapse: never\n") continue } - term.Printf(" Next elapse: %s\n", next.Format(time.UnixDate)) - term.Printf(" (in UTC): %s\n", next.UTC().Format(time.UnixDate)) - term.Printf(" From now: %s left\n", next.Sub(now)) + terminal.Printf(" Next elapse: %s\n", next.Format(time.UnixDate)) + terminal.Printf(" (in UTC): %s\n", next.UTC().Format(time.UnixDate)) + terminal.Printf(" From now: %s left\n", next.Sub(now)) } - term.Print(platform.LineSeparator) + terminal.Print(platform.LineSeparator) } diff --git a/schedule/schedules_test.go b/schedule/schedules_test.go index 1c014fef..6a7357c9 100644 --- a/schedule/schedules_test.go +++ b/schedule/schedules_test.go @@ -2,7 +2,6 @@ package schedule import ( "bytes" - "os" "testing" "github.com/creativeprojects/resticprofile/term" @@ -37,11 +36,10 @@ func TestDisplayParseSchedules(t *testing.T) { events, err := parseSchedules([]string{"daily"}) require.NoError(t, err) - buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) - displayParsedSchedules("profile", "command", events) + displayParsedSchedules(terminal, "profile", "command", events) output := buffer.String() assert.Contains(t, output, "Original form: daily\n") assert.Contains(t, output, "Normalized form: *-*-* 00:00:00\n") @@ -51,11 +49,10 @@ func TestDisplayParseSchedulesWillNeverRun(t *testing.T) { events, err := parseSchedules([]string{"2020-01-01"}) require.NoError(t, err) - buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) - displayParsedSchedules("profile", "command", events) + displayParsedSchedules(terminal, "profile", "command", events) output := buffer.String() assert.Contains(t, output, "Next elapse: never\n") } @@ -64,11 +61,10 @@ func TestDisplayParseSchedulesIndexAndTotal(t *testing.T) { events, err := parseSchedules([]string{"daily", "monthly", "yearly"}) require.NoError(t, err) - buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) - displayParsedSchedules("profile", "command", events) + displayParsedSchedules(terminal, "profile", "command", events) output := buffer.String() assert.Contains(t, output, "schedule 1/3") assert.Contains(t, output, "schedule 2/3") diff --git a/schtasks/permission.go b/schtasks/permission.go index 3f212754..5004feb6 100644 --- a/schtasks/permission.go +++ b/schtasks/permission.go @@ -40,7 +40,7 @@ func userCredentials() (string, string, error) { fmt.Printf("\nCreating task for user %s\n", userName) fmt.Printf("Task Scheduler requires your Windows password to validate the task: ") - userPassword, err = term.ReadPassword() + userPassword, err = term.Get().ReadPassword() if err != nil { return "", "", err } diff --git a/schtasks/taskscheduler.go b/schtasks/taskscheduler.go index 06e1871c..3a44bea5 100644 --- a/schtasks/taskscheduler.go +++ b/schtasks/taskscheduler.go @@ -125,7 +125,7 @@ func Status(title, subtitle string) error { if len(info) < 2 { return ErrNotRegistered } - writer := tabwriter.NewWriter(term.GetOutput(), 2, 2, 2, ' ', tabwriter.AlignRight) + writer := tabwriter.NewWriter(term.Get().Stdout(), 2, 2, 2, ' ', tabwriter.AlignRight) fmt.Fprintf(writer, "Task:\t %s\n", getFirstField(info, "TaskName")) fmt.Fprintf(writer, "User:\t %s\n", getFirstField(info, "Run As User")) fmt.Fprintf(writer, "Logon Mode:\t %s\n", getFirstField(info, "Logon Mode")) diff --git a/serve.go b/serve.go index ef043f44..cc254011 100644 --- a/serve.go +++ b/serve.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "os" "os/signal" @@ -19,7 +18,7 @@ import ( "github.com/creativeprojects/resticprofile/ssh" ) -func serveCommand(w io.Writer, cmdCtx commandContext) error { +func serveCommand(cmdCtx commandContext) error { if len(cmdCtx.flags.resticArgs) < 2 { return fmt.Errorf("missing argument: port") } @@ -78,7 +77,7 @@ func serveProfiles(port string, config *config.Config, quit chan os.Signal) erro return nil } -func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { +func sendProfileCommand(cmdCtx commandContext) error { if len(cmdCtx.flags.resticArgs) < 2 { return fmt.Errorf("missing argument: remote name") } diff --git a/term/default.go b/term/default.go new file mode 100644 index 00000000..cf98883b --- /dev/null +++ b/term/default.go @@ -0,0 +1,34 @@ +package term + +import ( + "os" + "sync/atomic" + + "golang.org/x/term" +) + +var defaultTerminal atomic.Pointer[Terminal] + +// Get returns the default terminal. It will be initialized on first use with NewTerminal() if not set before. +// Ideally we should pass the terminal where we need it, but Get() can be safely used until the refactoring is finished. +func Get() *Terminal { + defaultTerminal.CompareAndSwap(nil, NewTerminal()) + return defaultTerminal.Load() +} + +// Set stores the default terminal, and returns the terminal reference for chaining. +func Set(t *Terminal) *Terminal { + defaultTerminal.Store(t) + return t +} + +// Size returns the width and height of the terminal session +func Size() (width, height int) { + fd := fdToInt(os.Stdout.Fd()) + var err error + width, height, err = term.GetSize(fd) + if err != nil { + width, height = 0, 0 + } + return +} diff --git a/term/default_test.go b/term/default_test.go new file mode 100644 index 00000000..f34eedfa --- /dev/null +++ b/term/default_test.go @@ -0,0 +1,43 @@ +package term + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetTerminalSingleton(t *testing.T) { + total := 10 + wg := new(sync.WaitGroup) + terminals := make([]*Terminal, total) + for i := range total { + wg.Go(func() { + terminals[i] = Get() + }) + } + wg.Wait() + + for i := range total { + assert.NotNil(t, terminals[i]) + assert.Same(t, terminals[0], terminals[i]) + } +} + +func TestSetAndGetTerminalSingleton(t *testing.T) { + terminal := Set(NewTerminal()) + total := 10 + wg := new(sync.WaitGroup) + terminals := make([]*Terminal, total) + for i := range total { + wg.Go(func() { + terminals[i] = Get() + }) + } + wg.Wait() + + for i := range total { + assert.NotNil(t, terminals[i]) + assert.Same(t, terminal, terminals[i]) + } +} diff --git a/term/nil_reader.go b/term/nil_reader.go new file mode 100644 index 00000000..4ae07109 --- /dev/null +++ b/term/nil_reader.go @@ -0,0 +1,9 @@ +package term + +import "io" + +type nilReader struct{} + +func (nilReader) Read(p []byte) (int, error) { + return 0, io.EOF +} diff --git a/term/nil_reader_test.go b/term/nil_reader_test.go new file mode 100644 index 00000000..0d7d1895 --- /dev/null +++ b/term/nil_reader_test.go @@ -0,0 +1,16 @@ +package term + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadingFromNilReader(t *testing.T) { + r := new(nilReader) + buffer, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, 0, len(buffer)) +} diff --git a/term/recorder.go b/term/recorder.go new file mode 100644 index 00000000..756bf711 --- /dev/null +++ b/term/recorder.go @@ -0,0 +1,39 @@ +package term + +import ( + "errors" + "fmt" + "io" + "os" + "sync" +) + +type Recorder struct { + inputWriter *os.File + inputReader *os.File + wg sync.WaitGroup +} + +func NewRecorder(output io.Writer) (*Recorder, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, fmt.Errorf("cannot create OS pipe: %w", err) + } + recorder := &Recorder{ + inputReader: r, + inputWriter: w, + } + recorder.wg.Go(func() { + _, _ = io.Copy(output, recorder.inputReader) + }) + return recorder, nil +} + +func (recorder *Recorder) Close() error { + var err error + err = errors.Join(err, recorder.inputWriter.Close()) + recorder.wg.Wait() + // now that we finished reading, we can close the reader side + err = errors.Join(err, recorder.inputReader.Close()) + return err +} diff --git a/term/term.go b/term/term.go deleted file mode 100644 index 5570c5fa..00000000 --- a/term/term.go +++ /dev/null @@ -1,194 +0,0 @@ -package term - -import ( - "bufio" - "fmt" - "io" - "os" - "strings" - "sync" - - colorable "github.com/mattn/go-colorable" - "golang.org/x/term" -) - -var ( - terminalOutput io.Writer = os.Stdout - errorOutput io.Writer = os.Stderr - PrintToError = false -) - -// Flusher allows a Writer to declare it may buffer content that can be flushed -type Flusher interface { - // Flush writes any pending bytes to output - Flush() error -} - -// AskYesNo prompts the user for a message asking for a yes/no answer -func AskYesNo(reader io.Reader, message string, defaultAnswer bool) bool { - if !strings.HasSuffix(message, "?") { - message += "?" - } - var question, input string - if defaultAnswer { - question = "(Y/n)" - input = "y" - } else { - question = "(y/N)" - input = "n" - } - fmt.Printf("%s %s: ", message, question) - scanner := bufio.NewScanner(reader) - if scanner.Scan() { - input = strings.TrimSpace(strings.ToLower(scanner.Text())) - if len(input) > 1 { - // take only the first character - input = input[:1] - } - } - - if input == "" { - return defaultAnswer - } - if input == "y" { - return true - } - return false -} - -// ReadPassword reads a password without echoing it to the terminal. -func ReadPassword() (string, error) { - stdin := fdToInt(os.Stdin.Fd()) - if !term.IsTerminal(stdin) { - return ReadLine() - } - line, err := term.ReadPassword(stdin) - _, _ = fmt.Fprintln(os.Stderr) - if err != nil { - return "", fmt.Errorf("failed to read password: %w", err) - } - return string(line), nil -} - -// ReadLine reads some input -func ReadLine() (string, error) { - buf := bufio.NewReader(os.Stdin) - line, err := buf.ReadString('\n') - if err != nil { - return "", fmt.Errorf("failed to read line: %w", err) - } - return strings.TrimSpace(line), nil -} - -// OsStdoutIsTerminal returns true as os.Stdout is a terminal session -func OsStdoutIsTerminal() bool { - fd := fdToInt(os.Stdout.Fd()) - return term.IsTerminal(fd) -} - -// OsStdoutTerminalSize returns the current width and height of os.Stdout -func OsStdoutTerminalSize() (width, height int) { - fd := fdToInt(os.Stdout.Fd()) - var err error - width, height, err = term.GetSize(fd) - if err != nil { - width, height = 0, 0 - } - return -} - -func fdToInt(fd uintptr) int { - return int(fd) //nolint:gosec -} - -type LockedWriter struct { - writer io.Writer - mutex *sync.Mutex -} - -func (w *LockedWriter) Write(p []byte) (n int, err error) { - w.mutex.Lock() - defer w.mutex.Unlock() - return w.writer.Write(p) -} - -func (w *LockedWriter) Flush() (err error) { - w.mutex.Lock() - defer w.mutex.Unlock() - if f, ok := w.writer.(Flusher); ok { - err = f.Flush() - } - return -} - -// SetOutput changes the default output for the Print* functions -func SetOutput(w io.Writer) { - terminalOutput = &LockedWriter{writer: w, mutex: new(sync.Mutex)} -} - -// GetOutput returns the default output of the Print* functions -func GetOutput() io.Writer { - return terminalOutput -} - -// GetColorableOutput returns an output supporting ANSI color if output is a terminal -func GetColorableOutput() io.Writer { - out := GetOutput() - if out == os.Stdout && OsStdoutIsTerminal() { - return colorable.NewColorable(os.Stdout) - } - return colorable.NewNonColorable(out) -} - -// SetErrorOutput changes the error output for the Print* functions -func SetErrorOutput(w io.Writer) { - errorOutput = &LockedWriter{writer: w, mutex: new(sync.Mutex)} -} - -// GetErrorOutput returns the error output of the Print* functions -func GetErrorOutput() io.Writer { - return errorOutput -} - -// SetAllOutput changes the default and error output for the Print* functions -func SetAllOutput(w io.Writer) { - single := new(sync.Mutex) - terminalOutput = &LockedWriter{writer: w, mutex: single} - errorOutput = &LockedWriter{writer: w, mutex: single} -} - -// FlushAllOutput flushes all buffered output (if supported by the underlying Writer). -func FlushAllOutput() { - for _, writer := range []io.Writer{terminalOutput, errorOutput} { - if f, ok := writer.(Flusher); ok { - _ = f.Flush() - } - } -} - -// Print formats using the default formats for its operands and writes to standard output. -// Spaces are added between operands when neither is a string. -// It returns the number of bytes written and any write error encountered. -func Print(a ...any) (n int, err error) { - return fmt.Fprint(outputWriter(), a...) -} - -// Println formats using the default formats for its operands and writes to standard output. -// Spaces are always added between operands and a newline is appended. -// It returns the number of bytes written and any write error encountered. -func Println(a ...any) (n int, err error) { - return fmt.Fprintln(outputWriter(), a...) -} - -// Printf formats according to a format specifier and writes to standard output. -// It returns the number of bytes written and any write error encountered. -func Printf(format string, a ...any) (n int, err error) { - return fmt.Fprintf(outputWriter(), format, a...) -} - -func outputWriter() io.Writer { - if PrintToError { - return errorOutput - } - return terminalOutput -} diff --git a/term/term_test.go b/term/term_test.go deleted file mode 100644 index 9215747c..00000000 --- a/term/term_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package term - -import ( - "bytes" - "log" - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -type askYesNoTestData struct { - input string - defaultAnswer bool - expected bool -} - -func TestAskYesNo(t *testing.T) { - testData := []askYesNoTestData{ - // Empty answer => will follow the defaultAnswer - {"", true, true}, - {"", false, false}, - {"\n", true, true}, - {"\n", false, false}, - {"\r\n", true, true}, - {"\r\n", false, false}, - // Garbage answer => will always return false - {"aa", true, false}, - {"aa", false, false}, - {"aa\n", true, false}, - {"aa\n", false, false}, - {"aa\r\n", true, false}, - {"aa\r\n", false, false}, - // Answer yes - {"y", true, true}, - {"y", false, true}, - {"y\n", true, true}, - {"y\n", false, true}, - {"y\r\n", true, true}, - {"y\r\n", false, true}, - // Full answer yes - {"yes", true, true}, - {"yes", false, true}, - {"yes\n", true, true}, - {"yes\n", false, true}, - {"yes\r\n", true, true}, - {"yes\r\n", false, true}, - // Answer no - {"n", true, false}, - {"n", false, false}, - {"n\n", true, false}, - {"n\n", false, false}, - {"n\r\n", true, false}, - {"n\r\n", false, false}, - // Full answer no - {"no", true, false}, - {"no", false, false}, - {"no\n", true, false}, - {"no\n", false, false}, - {"no\r\n", true, false}, - {"no\r\n", false, false}, - } - for _, testItem := range testData { - result := AskYesNo( - bytes.NewBufferString(testItem.input), - "message", - testItem.defaultAnswer, - ) - assert.Equalf(t, testItem.expected, result, "when input was %q", testItem.input) - } -} - -func ExamplePrint() { - PrintToError = false - SetOutput(os.Stdout) - _, err := Print("ExamplePrint") - if err != nil { - log.Fatal(err) - } - // Output: ExamplePrint -} - -func TestCanRedirectTermOutput(t *testing.T) { - PrintToError = false - message := "TestCanRedirectTermOutput" - outputBuffer := &bytes.Buffer{} - errorBuffer := &bytes.Buffer{} - SetOutput(outputBuffer) - SetErrorOutput(errorBuffer) - _, err := Print(message) - assert.NoError(t, err) - assert.Equal(t, message, outputBuffer.String()) - assert.Empty(t, errorBuffer.String()) -} - -func TestCanRedirectTermErrorOutput(t *testing.T) { - PrintToError = true - message := "TestCanRedirectTermOutput" - outputBuffer := &bytes.Buffer{} - errorBuffer := &bytes.Buffer{} - SetOutput(outputBuffer) - SetErrorOutput(errorBuffer) - _, err := Print(message) - assert.NoError(t, err) - assert.Equal(t, message, errorBuffer.String()) - assert.Empty(t, outputBuffer.String()) -} diff --git a/term/terminal.go b/term/terminal.go new file mode 100644 index 00000000..860c59a6 --- /dev/null +++ b/term/terminal.go @@ -0,0 +1,208 @@ +package term + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" + + "github.com/creativeprojects/resticprofile/util" + "github.com/creativeprojects/resticprofile/util/maybe" + "github.com/mattn/go-colorable" + "golang.org/x/term" +) + +// Journey of a message sent to the terminal: +// 1. entrypoint is inputStdout/inputStderr +// 2. data is sent to colorableStdout/colorableStderr +// 3. a copy of the data is also sent to copyStdout/Stderr if enabled (final writer for the copy) +// 4. the data is finally written to the stdin/stdout writers (coming from colorable writers) + +// Terminal gives access to the standard terminal input/output +type Terminal struct { + stdin io.Reader + stdout io.Writer // final writer + stderr io.Writer // final writer + enableColors maybe.Bool + colorableStdout io.Writer // colorable writer + colorableStderr io.Writer // colorable writer + copyStdout io.Writer // stdout duplicate + copyStderr io.Writer // stderr duplicate + inputStdout io.Writer // entry for stdout + inputStderr io.Writer // entry for stderr +} + +func NewTerminal(options ...TerminalOption) *Terminal { + t := &Terminal{ + stdin: os.Stdin, + stdout: os.Stdout, + stderr: os.Stderr, + } + + for _, option := range options { + option(t) + } + + t.colorableStdout = t.getColorableWriter(t.stdout) + t.colorableStderr = t.getColorableWriter(t.stderr) + if t.copyStdout == nil { + // no copy, we send to colorable directly + t.inputStdout = t.colorableStdout + } else { + // send to both the output and the copy + t.inputStdout = io.MultiWriter(t.colorableStdout, t.getColorableWriter(t.copyStdout)) + } + if t.copyStderr == nil { + // no copy, we send to colorable directly + t.inputStderr = t.colorableStderr + } else { + // send to both the output and the copy + t.inputStderr = io.MultiWriter(t.colorableStderr, t.getColorableWriter(t.copyStderr)) + } + + return t +} + +// AskYesNo prompts the user for a message asking for a yes/no answer +func (t *Terminal) AskYesNo(message string, defaultAnswer bool) bool { + if !strings.HasSuffix(message, "?") { + message += "?" + } + var question, input string + if defaultAnswer { + question = "(Y/n)" + input = "y" + } else { + question = "(y/N)" + input = "n" + } + _, _ = t.Printf("%s %s: ", message, question) + scanner := bufio.NewScanner(t.stdin) + if scanner.Scan() { + input = strings.TrimSpace(strings.ToLower(scanner.Text())) + if len(input) > 1 { + // take only the first character + input = input[:1] + } + } + + if input == "" { + return defaultAnswer + } + if input == "y" { + return true + } + return false +} + +// ReadPassword reads a password without echoing it to the terminal. +func (t *Terminal) ReadPassword() (string, error) { + stdin, ok := t.stdin.(*os.File) + if !ok || !isTerminal(stdin) { + return t.readLine() + } + line, err := term.ReadPassword(fdToInt(stdin.Fd())) + if err != nil { + return "", fmt.Errorf("failed to read password: %w", err) + } + return string(line), nil +} + +// ReadLine reads some input +func (t *Terminal) readLine() (string, error) { + buf := bufio.NewReader(t.stdin) + line, err := buf.ReadString('\n') + if err != nil { + return "", fmt.Errorf("failed to read line: %w", err) + } + return strings.TrimSpace(line), nil +} + +// StdoutIsTerminal returns true as stdout is a terminal session +func (t *Terminal) StdoutIsTerminal() bool { + return isTerminalWriter(t.stdout) +} + +// StderrIsTerminal returns true as stderr is a terminal session +func (t *Terminal) StderrIsTerminal() bool { + return isTerminalWriter(t.stderr) +} + +func (t *Terminal) getColorableWriter(w io.Writer) io.Writer { + if file, ok := w.(*os.File); ok && t.enableColors.IsTrueOrUndefined() && (isTerminal(file) || t.enableColors.IsTrue()) { + return colorable.NewColorable(file) + } + return colorable.NewNonColorable(w) +} + +// FlushAllOutput flushes all buffered output (if supported by the underlying Writer). +func (t *Terminal) FlushAllOutput() { + for _, writer := range []io.Writer{ + t.inputStdout, t.inputStderr, + t.copyStdout, t.copyStderr, + t.colorableStdout, t.colorableStderr, + t.stdout, t.stderr, + } { + _, _ = util.FlushWriter(writer) + } +} + +// Print formats using the default formats for its operands and writes to standard output. +// Spaces are added between operands when neither is a string. +// It returns the number of bytes written and any write error encountered. +func (t *Terminal) Print(a ...any) (n int, err error) { + return fmt.Fprint(t.inputStdout, a...) +} + +// Println formats using the default formats for its operands and writes to standard output. +// Spaces are always added between operands and a newline is appended. +// It returns the number of bytes written and any write error encountered. +func (t *Terminal) Println(a ...any) (n int, err error) { + return fmt.Fprintln(t.inputStdout, a...) +} + +// Printf formats according to a format specifier and writes to standard output. +// It returns the number of bytes written and any write error encountered. +func (t *Terminal) Printf(format string, a ...any) (n int, err error) { + return fmt.Fprintf(t.inputStdout, format, a...) +} + +func (t *Terminal) Scanln(a ...any) (n int, err error) { + return fmt.Fscanln(t.stdin, a...) +} + +// Write implements the io.Writer interface, writing to the terminal's stdout. +func (t *Terminal) Write(p []byte) (n int, err error) { + return t.inputStdout.Write(p) +} + +func (t *Terminal) Stdout() io.Writer { + return t.inputStdout +} + +func (t *Terminal) Stderr() io.Writer { + return t.inputStderr +} + +func isTerminalWriter(w io.Writer) bool { + file, ok := w.(*os.File) + if !ok { + return false + } + return isTerminal(file) +} + +func isTerminal(file *os.File) bool { + if file == nil { + return false + } + fd := fdToInt(file.Fd()) + return term.IsTerminal(fd) +} + +func fdToInt(fd uintptr) int { + return int(fd) //nolint:gosec +} + +var _ io.Writer = (*Terminal)(nil) diff --git a/term/terminal_option.go b/term/terminal_option.go new file mode 100644 index 00000000..81da50e3 --- /dev/null +++ b/term/terminal_option.go @@ -0,0 +1,94 @@ +package term + +import ( + "io" + "os" + + "github.com/creativeprojects/resticprofile/util" + "github.com/creativeprojects/resticprofile/util/maybe" +) + +type TerminalOption func(t *Terminal) + +func WithNoStdin() TerminalOption { + return func(t *Terminal) { + t.stdin = nilReader{} + } +} + +func WithNoStdout() TerminalOption { + return func(t *Terminal) { + t.stdout = io.Discard + } +} + +func WithNoStderr() TerminalOption { + return func(t *Terminal) { + t.stderr = io.Discard + } +} + +func WithStdin(stdin io.Reader) TerminalOption { + if stdin == nil { + return func(t *Terminal) {} + } + return func(t *Terminal) { + t.stdin = stdin + } +} + +func WithStdout(stdout io.Writer) TerminalOption { + if stdout == nil { + return func(t *Terminal) {} + } + if stdout != os.Stdout && stdout != os.Stderr { + stdout = util.NewSyncWriter(stdout) + } + return func(t *Terminal) { + t.stdout = stdout + } +} + +func WithStderr(stderr io.Writer) TerminalOption { + if stderr == nil { + return func(t *Terminal) {} + } + if stderr != os.Stdout && stderr != os.Stderr { + stderr = util.NewSyncWriter(stderr) + } + return func(t *Terminal) { + t.stderr = stderr + } +} + +func WithColors(enable bool) TerminalOption { + return func(t *Terminal) { + t.enableColors = maybe.SetBool(enable) + } +} + +func WithStdoutRecorder(recorder *Recorder) TerminalOption { + return func(t *Terminal) { + t.stdout = recorder.inputWriter + } +} + +func WithStderrRecorder(recorder *Recorder) TerminalOption { + return func(t *Terminal) { + t.stderr = recorder.inputWriter + } +} + +// WithStdoutCopy creates a copy of everything sent to Stdout to `copy` writer. +// Colorisation is independent: the Stdout writer can display colors while the copy to a non terminal won't display colors. +func WithStdoutCopy(w io.Writer) TerminalOption { + return func(t *Terminal) { + t.copyStdout = w + } +} + +func WithStderrCopy(w io.Writer) TerminalOption { + return func(t *Terminal) { + t.copyStderr = w + } +} diff --git a/term/terminal_option_test.go b/term/terminal_option_test.go new file mode 100644 index 00000000..abfeeaf0 --- /dev/null +++ b/term/terminal_option_test.go @@ -0,0 +1,45 @@ +package term + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTerminalNoInput(t *testing.T) { + terminal := NewTerminal(WithNoStdin()) + pwd, err := terminal.ReadPassword() + require.Error(t, err) + assert.Empty(t, pwd) + + var line string + read, err := terminal.Scanln(&line) + require.Error(t, err) + assert.Empty(t, line) + assert.Empty(t, read) +} + +func TestTerminalNoStdout(t *testing.T) { + terminal := NewTerminal(WithNoStdout()) + written, err := terminal.Print("something") + require.NoError(t, err) + assert.Equal(t, 9, written) +} + +func TestTerminalNoStderr(t *testing.T) { + terminal := NewTerminal(WithNoStderr()) + written, err := fmt.Fprint(terminal.Stderr(), "something") + require.NoError(t, err) + assert.Equal(t, 9, written) +} + +func TestTestTerminalStdoutCopy(t *testing.T) { + buffer := new(bytes.Buffer) + terminal := NewTerminal(WithStdoutCopy(buffer)) + _, err := terminal.Printf("%s test", "copy") + require.NoError(t, err) + assert.Equal(t, "copy test", buffer.String()) +} diff --git a/term/terminal_test.go b/term/terminal_test.go new file mode 100644 index 00000000..7ded96a5 --- /dev/null +++ b/term/terminal_test.go @@ -0,0 +1,190 @@ +package term + +import ( + "bytes" + "io" + "log" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTerminalAskYesNo(t *testing.T) { + t.Parallel() + + fixtures := []struct { + input string + defaultAnswer bool + expected bool + }{ + // Empty answer => will follow the defaultAnswer + {"", true, true}, + {"", false, false}, + {"\n", true, true}, + {"\n", false, false}, + {"\r\n", true, true}, + {"\r\n", false, false}, + // Garbage answer => will always return false + {"aa", true, false}, + {"aa", false, false}, + {"aa\n", true, false}, + {"aa\n", false, false}, + {"aa\r\n", true, false}, + {"aa\r\n", false, false}, + // Answer yes + {"y", true, true}, + {"y", false, true}, + {"y\n", true, true}, + {"y\n", false, true}, + {"y\r\n", true, true}, + {"y\r\n", false, true}, + // Full answer yes + {"yes", true, true}, + {"yes", false, true}, + {"yes\n", true, true}, + {"yes\n", false, true}, + {"yes\r\n", true, true}, + {"yes\r\n", false, true}, + // Answer no + {"n", true, false}, + {"n", false, false}, + {"n\n", true, false}, + {"n\n", false, false}, + {"n\r\n", true, false}, + {"n\r\n", false, false}, + // Full answer no + {"no", true, false}, + {"no", false, false}, + {"no\n", true, false}, + {"no\n", false, false}, + {"no\r\n", true, false}, + {"no\r\n", false, false}, + } + for _, fixture := range fixtures { + output := new(bytes.Buffer) + terminal := NewTerminal(WithStdin(bytes.NewBufferString(fixture.input)), WithStdout(output)) + result := terminal.AskYesNo("message", fixture.defaultAnswer) + assert.Contains(t, output.String(), "message? (") + assert.Equalf(t, fixture.expected, result, "when input was %q", fixture.input) + } +} + +func ExamplePrint() { + terminal := NewTerminal() + _, err := terminal.Print("ExampleTerminalPrint") + if err != nil { + log.Fatal(err) + } + // Output: ExampleTerminalPrint +} + +func TestDefaultTerminal(t *testing.T) { + terminal := NewTerminal() + // in a test environment, stdout and stderr are not terminals, so we expect false for both + assert.False(t, terminal.StdoutIsTerminal()) + assert.False(t, terminal.StderrIsTerminal()) +} + +func TestReadPasswordFromBuffer(t *testing.T) { + input := "mysecretpassword\n" + terminal := NewTerminal(WithStdin(bytes.NewBufferString(input))) + password, err := terminal.ReadPassword() + assert.NoError(t, err) + assert.Equal(t, "mysecretpassword", password) +} + +func TestTerminalOutputCapture(t *testing.T) { + buffer := new(bytes.Buffer) + recorder, err := NewRecorder(buffer) + require.NoError(t, err) + + terminal := NewTerminal(WithStdoutRecorder(recorder)) + written, err := terminal.Print("TestTerminalOutputCapture") + require.NoError(t, err) + assert.Equal(t, 25, written) + + err = recorder.Close() + require.NoError(t, err) + assert.Equal(t, "TestTerminalOutputCapture", buffer.String()) +} + +// ansiText is a string with ANSI color codes that should be passed through or stripped depending on the writer. +const ansiText = "\x1b[31mhello\x1b[0m" +const plainText = "hello" + +func TestGetColorableWriterWithNonFileWriterReturnsNonColorable(t *testing.T) { + t.Parallel() + // A bytes.Buffer is not an *os.File, so getColorableWriter always wraps it as NonColorable. + buf := &bytes.Buffer{} + terminal := NewTerminal() + writer := terminal.getColorableWriter(buf) + + _, err := writer.Write([]byte(ansiText)) + require.NoError(t, err) + + // NonColorable strips ANSI escape codes. + assert.Equal(t, plainText, buf.String()) +} + +func TestGetColorableWriterWithColorsDisabledReturnsNonColorable(t *testing.T) { + t.Parallel() + // Even for an *os.File, explicitly disabling colors forces a NonColorable writer. + pr, pw, err := os.Pipe() + require.NoError(t, err) + defer pr.Close() + + terminal := NewTerminal(WithColors(false)) + writer := terminal.getColorableWriter(pw) + + _, err = writer.Write([]byte(ansiText)) + require.NoError(t, err) + pw.Close() + + out, err := io.ReadAll(pr) + require.NoError(t, err) + // NonColorable strips ANSI escape codes. + assert.Equal(t, plainText, string(out)) +} + +func TestGetColorableWriterWithColorsUndefinedOnNonTerminalFileReturnsNonColorable(t *testing.T) { + t.Parallel() + // With undefined colors and a non-terminal *os.File (e.g. a pipe), the writer is NonColorable. + pr, pw, err := os.Pipe() + require.NoError(t, err) + defer pr.Close() + + terminal := NewTerminal() // enableColors is undefined + writer := terminal.getColorableWriter(pw) + + _, err = writer.Write([]byte(ansiText)) + require.NoError(t, err) + pw.Close() + + out, err := io.ReadAll(pr) + require.NoError(t, err) + // NonColorable strips ANSI escape codes. + assert.Equal(t, plainText, string(out)) +} + +func TestGetColorableWriterWithColorsEnabledOnNonTerminalFileReturnsColorable(t *testing.T) { + t.Parallel() + // With colors explicitly enabled, getColorableWriter returns a Colorable writer even for a + // non-terminal *os.File, so ANSI codes are passed through unchanged. + pr, pw, err := os.Pipe() + require.NoError(t, err) + defer pr.Close() + + terminal := NewTerminal(WithColors(true)) + writer := terminal.getColorableWriter(pw) + + _, err = writer.Write([]byte(ansiText)) + require.NoError(t, err) + pw.Close() + + out, err := io.ReadAll(pr) + require.NoError(t, err) + // Colorable preserves ANSI escape codes. + assert.Equal(t, ansiText, string(out)) +} diff --git a/term/test/main.go b/term/test/main.go new file mode 100644 index 00000000..20d3028a --- /dev/null +++ b/term/test/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "github.com/creativeprojects/resticprofile/term" + "github.com/creativeprojects/resticprofile/util/ansi" +) + +func main() { + terminal := term.NewTerminal() + terminal.Println(ansi.Bold("colorable terminal")) + terminal.Printf("stdout is terminal: %v, stderr is terminal: %v\n", terminal.StdoutIsTerminal(), terminal.StderrIsTerminal()) + + terminal = term.NewTerminal(term.WithColors(false)) + terminal.Println(ansi.Bold("non colorable terminal")) + terminal.Printf("stdout is terminal: %v, stderr is terminal: %v\n", terminal.StdoutIsTerminal(), terminal.StderrIsTerminal()) +} diff --git a/update.go b/update.go index ed89a0c6..2226fab5 100644 --- a/update.go +++ b/update.go @@ -6,7 +6,6 @@ import ( "context" "errors" "fmt" - "io" "os" "runtime" @@ -32,19 +31,19 @@ func init() { config.ExcludeProfileSection(def.name) } -func selfUpdate(_ io.Writer, ctx commandContext) error { +func selfUpdate(ctx commandContext) error { quiet := ctx.flags.quiet if !quiet && len(ctx.request.arguments) > 0 && (ctx.request.arguments[0] == "-q" || ctx.request.arguments[0] == "--quiet") { quiet = true } - err := confirmAndSelfUpdate(quiet, ctx.flags.verbose, version, true) + err := confirmAndSelfUpdate(ctx.terminal, quiet, ctx.flags.verbose, version, true) if err != nil { return err } return nil } -func confirmAndSelfUpdate(quiet, debug bool, version string, prerelease bool) error { +func confirmAndSelfUpdate(terminal *term.Terminal, quiet, debug bool, version string, prerelease bool) error { if debug { selfupdate.SetLogger(clog.NewStandardLogger(clog.LevelDebug, clog.GetDefaultLogger())) } @@ -70,8 +69,8 @@ func confirmAndSelfUpdate(quiet, debug bool, version string, prerelease bool) er } // don't ask in quiet mode - if !quiet && !term.AskYesNo(os.Stdin, fmt.Sprintf("Do you want to update to version %s", latest.Version()), true) { - term.Println("Never mind") + if !quiet && !terminal.AskYesNo(fmt.Sprintf("Do you want to update to version %s", latest.Version()), true) { + terminal.Println("Never mind") return nil } diff --git a/update_test.go b/update_test.go index ddf21e93..a17d097c 100644 --- a/update_test.go +++ b/update_test.go @@ -7,6 +7,7 @@ import ( "github.com/creativeprojects/clog" "github.com/creativeprojects/go-selfupdate" + "github.com/creativeprojects/resticprofile/term" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,7 +20,7 @@ func TestUpdate(t *testing.T) { clog.SetTestLog(t) defer clog.CloseTestLog() - err := confirmAndSelfUpdate(true, true, "0.0.1", false) + err := confirmAndSelfUpdate(term.NewTerminal(), true, true, "0.0.1", false) require.ErrorIsf(t, err, selfupdate.ErrExecutableNotFoundInArchive, "error returned isn't wrapping %q but is instead: %q", selfupdate.ErrExecutableNotFoundInArchive, err) assert.Contains(t, err.Error(), "resticprofile.test") } diff --git a/util/ansi/colors.go b/util/ansi/colors.go new file mode 100644 index 00000000..748afcc6 --- /dev/null +++ b/util/ansi/colors.go @@ -0,0 +1,39 @@ +package ansi + +import ( + "strings" + + "github.com/fatih/color" +) + +var ( + bold = color.New(color.Bold) + Bold = bold.SprintFunc() + cyan = color.New(color.FgCyan) + Cyan = cyan.SprintFunc() + gray = New256FgColor(243) + Gray = gray.SprintFunc() + green = color.New(color.FgGreen) + Green = green.SprintFunc() + yellow = color.New(color.FgYellow) + Yellow = yellow.SprintFunc() + underline = color.New(color.Underline) + Underline = underline.Sprint() +) + +// New256FgColor return a new xterm 256 (8bit) foreground color +func New256FgColor(code uint8) *color.Color { + return color.New(38, 5, color.Attribute(code)) +} + +// New256BgColor return a new xterm 256 (8bit) background color +func New256BgColor(code uint8) *color.Color { + return color.New(48, 5, color.Attribute(code)) +} + +func ColorSequence(fn func(a ...any) string) (start, stop string) { + s := fn("||") + before, after, _ := strings.Cut(s, "||") + start, stop = before, after + return +} diff --git a/util/ansi/colors_test.go b/util/ansi/colors_test.go new file mode 100644 index 00000000..03513d57 --- /dev/null +++ b/util/ansi/colors_test.go @@ -0,0 +1,19 @@ +package ansi + +import ( + "strings" + "testing" + + "github.com/fatih/color" + "github.com/stretchr/testify/assert" +) + +func Test256Colors(t *testing.T) { + test := func(color *color.Color, expected string) { + color.EnableColor() + seq := strings.Split(color.Sprint("||"), "||")[0] + assert.Equal(t, expected, seq) + } + test(New256FgColor(10), Sequence('m', 38, 5, 10)) + test(New256BgColor(10), Sequence('m', 48, 5, 10)) +} diff --git a/util/ansi/escape.go b/util/ansi/escape.go new file mode 100644 index 00000000..5ca14831 --- /dev/null +++ b/util/ansi/escape.go @@ -0,0 +1,36 @@ +package ansi + +import ( + "fmt" + "strings" +) + +const ( + EscapeByte = 0x1b + Escape = "\x1b" + ClearLine = Escape + "[2K" + Reset = Escape + "[0m" +) + +// Sequence builds an arbitrary escape sequence +func Sequence(terminator byte, attributes ...any) string { + seq := strings.Builder{} + seq.WriteString(Escape) + seq.WriteByte('[') + for i, attribute := range attributes { + if i > 0 { + seq.WriteByte(';') + } + _, _ = fmt.Fprint(&seq, attribute) + } + seq.WriteByte(terminator) + return seq.String() +} + +// CursorUpLeftN creates the escape sequence to move the cursor left and up N lines +func CursorUpLeftN(lines int) string { + if lines < 0 { + lines = 0 + } + return Sequence('F', lines) +} diff --git a/util/ansi/escape_test.go b/util/ansi/escape_test.go new file mode 100644 index 00000000..4b25be8f --- /dev/null +++ b/util/ansi/escape_test.go @@ -0,0 +1,14 @@ +package ansi + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCursorUpLeftN(t *testing.T) { + assert.Equal(t, Escape+"[0F", CursorUpLeftN(-1)) + assert.Equal(t, Escape+"[0F", CursorUpLeftN(0)) + assert.Equal(t, Escape+"[1F", CursorUpLeftN(1)) + assert.Equal(t, Escape+"[2F", CursorUpLeftN(2)) +} diff --git a/util/ansi/linewriter.go b/util/ansi/linewriter.go new file mode 100644 index 00000000..b9ed0297 --- /dev/null +++ b/util/ansi/linewriter.go @@ -0,0 +1,110 @@ +package ansi + +import ( + "io" + "unicode/utf8" +) + +type lineLengthWriter struct { + writer io.Writer + tokens []byte + maxLineLength, lastWhite, breakLength, lineLength int + invisibleLength, lastWhiteInvisibleLength int + inAnsi bool +} + +// NewLineLengthWriter return an io.Writer that limits the max line length, adding line breaks ('\n') as needed. +// The writer detects the right most column (consecutive whitespace) and aligns content if possible. +// UTF sequences are counted as single character and ANSI escape sequences are not counted at all. +func NewLineLengthWriter(writer io.Writer, maxLineLength int) io.Writer { + return &lineLengthWriter{ + tokens: []byte{' ', '\n'}, + writer: writer, + maxLineLength: maxLineLength, + } +} + +func (l *lineLengthWriter) visibleLineLength() int { return l.lineLength - l.invisibleLength } + +func (l *lineLengthWriter) Write(p []byte) (n int, err error) { + var written int + offset := l.lineLength + + for i := 0; i < len(p); i++ { + l.lineLength++ + ws := p[i] == l.tokens[0] // whitespace + br := p[i] == l.tokens[1] // linebreak + + // don't count ansi control sequences + if l.inAnsi = l.inAnsi || p[i] == EscapeByte; l.inAnsi { + terminator := (p[i] >= 'a' && p[i] <= 'z') || (p[i] >= 'A' && p[i] <= 'Z') + l.inAnsi = !terminator + l.invisibleLength++ + continue + } + + // count UTF sequence as one character + if p[i] >= utf8.RuneSelf { + if !utf8.RuneStart(p[i]) { + l.invisibleLength++ + } + continue + } + + if !br && l.visibleLineLength() > l.maxLineLength && l.lastWhite-offset > 0 { + lastWhiteIndex := l.lastWhite - offset - 1 + remainder := i - lastWhiteIndex + + if written, err = l.writer.Write(p[:lastWhiteIndex]); err == nil { + p = p[lastWhiteIndex+1:] + i = remainder - 1 + n += written + 1 + + _, _ = l.writer.Write(l.tokens[1:]) // write break (instead of WS at lastWhiteIndex) + for range l.breakLength { + _, _ = l.writer.Write(l.tokens[0:1]) // fill spaces for alignment + } + + l.lineLength = l.breakLength + remainder + l.lastWhite = l.breakLength + offset = l.breakLength + + l.invisibleLength -= l.lastWhiteInvisibleLength + l.lastWhiteInvisibleLength = 0 + } else { + return + } + } + + if ws { + if l.lastWhite == l.lineLength-1 && l.visibleLineLength() < l.maxLineLength*2/3 { + l.breakLength = l.visibleLineLength() + } + l.lastWhite = l.lineLength + l.lastWhiteInvisibleLength = l.invisibleLength + + } else if br { + if written, err = l.writer.Write(p[:i+1]); err == nil { + p = p[i+1:] + i = -1 + n += written + + l.lineLength = 0 + l.lastWhite = 0 + l.breakLength = 0 + offset = 0 + + l.invisibleLength = 0 + l.lastWhiteInvisibleLength = 0 + } else { + return + } + } + } + + // write remainder + if written, err = l.writer.Write(p); err == nil { + n += written + } + return +} diff --git a/util/ansi/linewriter_test.go b/util/ansi/linewriter_test.go new file mode 100644 index 00000000..62146d4e --- /dev/null +++ b/util/ansi/linewriter_test.go @@ -0,0 +1,280 @@ +package ansi + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/fatih/color" + "github.com/stretchr/testify/assert" +) + +var ansiColor = func() (c *color.Color) { + c = color.New(color.FgCyan) + c.EnableColor() + return +}() + +var colored = ansiColor.SprintFunc() + +func TestLineLengthWriter(t *testing.T) { + tests := []struct { + input, expected string + chunks, scale int + }{ + // test non-breakable + {input: strings.Repeat("-", 50), expected: strings.Repeat("-", 50), chunks: 15}, + + // test breakable without columns + { + input: strings.Repeat("word ", 20), + expected: "" + + strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + + strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + + strings.Repeat("word ", 4), + chunks: 5, scale: 6, + }, + + // test breakable with ANSI color + { + input: strings.Repeat(colored("word "), 20), + expected: "" + + strings.Repeat(colored("word "), 7) + colored("word\n") + + strings.Repeat(colored("word "), 7) + colored("word\n") + + strings.Repeat(colored("word "), 4), + }, + + // test breakable with 2 columns + { + input: "word word word word " + + strings.Repeat("word ", 20), + expected: "" + + "word word word word " + + "word word word\n" + + strings.Repeat(" word word word\n", 5) + + " word word ", + chunks: 3, scale: 15, + }, + + // test breakable with 2 columns and ANSI color + { + input: colored("word word word word ") + + strings.Repeat(colored("word "), 20), + expected: "" + + colored("word word word word ") + + colored("word ") + colored("word ") + colored("word\n ") + + strings.Repeat(colored("word ")+colored("word ")+colored("word\n "), 5) + + colored("word ") + colored("word "), + }, + + // test breakable with 2 columns and unicode character + { + input: "w😁rd wo😁d wor😁 😁ord 😁😁😁😁 w😁rd wo😁d wor😁 😁ord 😁😁😁😁", + expected: "w😁rd wo😁d wor😁 😁ord 😁😁😁😁 w😁rd wo😁d \n" + + " wor😁 😁ord 😁😁😁😁", + }, + { + input: "word word word word word word word word word word", + expected: "word word word word word word word \n" + + " word word word", + }, + + // test breakable with 2 columns, colors and unicode character + { + input: colored("w😁rd wo😁d wor😁 😁ord ") + + strings.Repeat(colored("wor😁 ")+colored("😁ord "), 2), + expected: "" + + colored("w😁rd wo😁d wor😁 😁ord ") + + colored("wor😁 ") + colored("😁ord ") + colored("wor😁\n ") + + colored("😁ord "), + }, + + // test real-world content + { + input: ` +Usage of resticprofile: + resticprofile [resticprofile flags] [profile name.][restic command] [restic flags] + resticprofile [resticprofile flags] [profile name.][resticprofile command] [command specific flags] + +resticprofile flags: + -c, --config string configuration file (default "profiles") + --dry-run display the restic commands instead of running them + -f, --format string file format of the configuration (default is to use the file extension) + -h, --help display this help + --lock-wait duration wait up to duration to acquire a lock (syntax "1h5m30s") + -l, --log string logs to a target instead of the console + -n, --name string profile name (default "default") + --no-ansi disable ansi control characters (disable console colouring) + --no-lock skip profile lock file + --no-prio don't set any priority on load: used when started from a service that has already set the priority + -q, --quiet display only warnings and errors + --theme string console colouring theme (dark, light, none) (default "light") + --trace display even more debugging information + -v, --verbose display some debugging information + -w, --wait wait at the end until the user presses the enter key + + +resticprofile own commands: + help display help (run in verbose mode for detailed information) + version display version (run in verbose mode for detailed information) + self-update update to latest resticprofile (use -q/--quiet flag to update without confirmation) + profiles display profile names from the configuration file + show show all the details of the current profile + schedule schedule jobs from a profile (use --all flag to schedule all jobs of all profiles) + unschedule remove scheduled jobs of a profile (use --all flag to unschedule all profiles) + status display the status of scheduled jobs (use --all flag for all profiles) + generate generate resources (--random-key [size], --bash-completion & --zsh-completion) + +Documentation available at https://creativeprojects.github.io/resticprofile/ +`, + expected: ` +Usage of resticprofile: + resticprofile [resticprofile flags] + [profile name.][restic command] + [restic flags] + resticprofile [resticprofile flags] + [profile name.][resticprofile + command] [command specific flags] + +resticprofile flags: + -c, --config string + configuration + file (default + "profiles") + --dry-run display + the restic + commands + instead of + running them + -f, --format string file + format of the + configuration + (default is to + use the file + extension) + -h, --help display + this help + --lock-wait duration wait up to + duration to acquire a lock + (syntax "1h5m30s") + -l, --log string logs to a + target instead + of the console + -n, --name string profile + name (default + "default") + --no-ansi disable + ansi control + characters + (disable + console + colouring) + --no-lock skip + profile lock + file + --no-prio don't set + any priority + on load: used + when started + from a service + that has + already set + the priority + -q, --quiet display + only warnings + and errors + --theme string console + colouring + theme (dark, + light, none) + (default + "light") + --trace display + even more + debugging + information + -v, --verbose display + some debugging + information + -w, --wait wait at + the end until + the user + presses the + enter key + + +resticprofile own commands: + help display help (run in + verbose mode for + detailed information) + version display version (run + in verbose mode for + detailed information) + self-update update to latest + resticprofile (use + -q/--quiet flag to + update without + confirmation) + profiles display profile names + from the configuration + file + show show all the details + of the current profile + schedule schedule jobs from a + profile (use --all + flag to schedule all + jobs of all profiles) + unschedule remove scheduled jobs + of a profile (use + --all flag to + unschedule all + profiles) + status display the status of + scheduled jobs (use + --all flag for all + profiles) + generate generate resources + (--random-key [size], + --bash-completion & + --zsh-completion) + +Documentation available at +https://creativeprojects.github.io/resticprofile/ +`, + }, + } + + for i, test := range tests { + if test.scale == 0 { + test.scale = 1 + } + for chunkSize := 0; chunkSize <= test.chunks; chunkSize++ { + t.Run(fmt.Sprintf("%d-%d", i, chunkSize), func(t *testing.T) { + buffer := bytes.Buffer{} + writer := NewLineLengthWriter(&buffer, 40) + input := []byte(test.input) + + var ( + n int + err error + ) + switch chunkSize { + case 0: + n, err = writer.Write(input) + assert.Equal(t, len(input), n) + default: + for len(input) > 0 && err == nil { + length := min(test.scale*chunkSize, len(input)) + n, err = writer.Write(input[:length]) + assert.Equal(t, length, n) + input = input[length:] + } + } + + assert.Nil(t, err) + assert.Equal(t, test.expected, buffer.String()) + }) + } + } +} diff --git a/util/ansi/runes.go b/util/ansi/runes.go new file mode 100644 index 00000000..da578842 --- /dev/null +++ b/util/ansi/runes.go @@ -0,0 +1,26 @@ +package ansi + +// RunesLength returns the visible content length (and index at the length) +func RunesLength(src []rune, maxLength int) (length, index int) { + esc := rune(Escape[0]) + inEsc := false + index = -1 + for i, r := range src { + if r == esc { + inEsc = true + } else if inEsc { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + inEsc = false + } + } else { + if length == maxLength { + index = i + } + length++ + } + } + if index < 0 { + index = len(src) + } + return +} diff --git a/util/ansi/runes_test.go b/util/ansi/runes_test.go new file mode 100644 index 00000000..c00452b2 --- /dev/null +++ b/util/ansi/runes_test.go @@ -0,0 +1,40 @@ +package ansi + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunesLength(t *testing.T) { + tests := []struct { + input []rune + max, index, length int + }{ + {input: []rune{}, index: 0, length: 0, max: -1}, + {input: []rune(""), index: 0, length: 0, max: -1}, + {input: []rune(ClearLine + ""), index: len(ClearLine), length: 0, max: -1}, + {input: []rune(ClearLine + "" + ClearLine), index: 2 * len(ClearLine), length: 0, max: -1}, + {input: []rune(ClearLine + "◷" + ClearLine), index: 1 + 2*len(ClearLine), length: 1, max: -1}, + + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: len(ClearLine), length: 3, max: 0}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: 1 + len(ClearLine), length: 3, max: 1}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: 2 + len(ClearLine), length: 3, max: 2}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: 3 + 2*len(ClearLine), length: 3, max: 3}, + + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 3 + 2*len(ClearLine), length: 6, max: 3}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 4 + 2*len(ClearLine), length: 6, max: 4}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 5 + 2*len(ClearLine), length: 6, max: 5}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 6 + 2*len(ClearLine), length: 6, max: 6}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 6 + 2*len(ClearLine), length: 6, max: 7}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 6 + 2*len(ClearLine), length: 6, max: -1}, + } + for idx, test := range tests { + t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) { + length, index := RunesLength(test.input, test.max) + assert.Equal(t, test.length, length, "length") + assert.Equal(t, test.index, index, "index") + }) + } +} diff --git a/util/write/append.go b/util/write/append.go new file mode 100644 index 00000000..cb263a29 --- /dev/null +++ b/util/write/append.go @@ -0,0 +1,39 @@ +package write + +import "io" + +// WriterAppendFunc is called for every input byte when appending it to the output buffer (is similar to buf = append(buf, byte)) +type WriterAppendFunc func(dst []byte, c byte) []byte + +type Append struct { + appender WriterAppendFunc + next io.Writer +} + +func NewAppend(w io.Writer, fn WriterAppendFunc) *Append { + return &Append{ + appender: fn, + next: w, + } +} + +func (a *Append) Write(data []byte) (int, error) { + if a.appender == nil { + return a.next.Write(data) + } + dst := make([]byte, 0, len(data)*2) + for _, c := range data { + dst = a.appender(dst, c) + } + return a.next.Write(dst) +} + +// Close will close the destination writer if accessible +func (a *Append) Close() error { + if closer, ok := a.next.(io.Closer); ok { + return closer.Close() + } + return nil +} + +var _ io.WriteCloser = &Append{} diff --git a/util/write/append_test.go b/util/write/append_test.go new file mode 100644 index 00000000..502b1891 --- /dev/null +++ b/util/write/append_test.go @@ -0,0 +1,51 @@ +package write + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAppendLinesNoAppender(t *testing.T) { + buffer := new(bytes.Buffer) + a := NewAppend(buffer, nil) + _, _ = a.Write([]byte("a\n")) + _, _ = a.Write([]byte("b\n")) + _, _ = a.Write([]byte("c\n")) + + assert.Equal(t, "a\nb\nc\n", buffer.String()) +} + +func TestAppendLinesWithAppender(t *testing.T) { + buffer := new(bytes.Buffer) + appender := func(dst []byte, c byte) []byte { + switch c { + case '\n': + return append(dst, '\r', '\n') // normalize to CRLF on Windows + case '\r': + return dst + } + return append(dst, c) + } + + a := NewAppend(buffer, appender) + _, _ = a.Write([]byte("a\n")) + _, _ = a.Write([]byte("b\n")) + _, _ = a.Write([]byte("c\n")) + + assert.Equal(t, "a\r\nb\r\nc\r\n", buffer.String()) +} + +func TestAppendCallsCloseOnTarget(t *testing.T) { + called := 0 + mockClose := newMockWriteCloser(nil, func() error { + called++ + return nil + }) + + a := NewAppend(mockClose, nil) + a.Close() + + assert.Equal(t, 1, called) +} diff --git a/util/write/async.go b/util/write/async.go new file mode 100644 index 00000000..e1876143 --- /dev/null +++ b/util/write/async.go @@ -0,0 +1,130 @@ +package write + +import ( + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" +) + +const ( + asyncWriterDataChanSize = 64 + asyncWriterFlushChanSize = 16 +) + +var ErrAlreadyClosed = errors.New("writer already closed") + +type Async struct { + handler io.Writer + interval time.Duration + data chan []byte + flusher chan chan error + done chan struct{} + systemGroup sync.WaitGroup + closeOnce sync.Once + closed atomic.Bool + flusherClosed atomic.Bool +} + +// NewAsync creates a writer that accumulates Write requests and writes them at a fixed rate (every 250 ms by default) +func NewAsync(handler io.Writer, options ...AsyncOption) *Async { + w := &Async{ + handler: handler, + interval: 250 * time.Millisecond, + data: make(chan []byte, asyncWriterDataChanSize), + flusher: make(chan chan error, asyncWriterFlushChanSize), + done: make(chan struct{}), // channel closed after the first call to Close() + } + for _, option := range options { + option(w) + } + w.systemGroup.Go(func() { + w.intervalFlush() + }) + w.systemGroup.Go(func() { + w.recvFlush() + }) + return w +} + +func (w *Async) intervalFlush() { + ticker := time.NewTicker(w.interval) + for { + select { + case <-ticker.C: + ticker.Stop() // don't keep piling up if the flusher channel is already full + w.flusher <- nil + ticker.Reset(w.interval) + case <-w.done: + ticker.Stop() + return + } + } +} + +func (w *Async) recvFlush() { + for done := range w.flusher { + err := w.flush() + if done != nil { // some calls don't need to wait for the answer + done <- err + } + } +} + +// Close the writer. Any more call to Write will be ignored. +func (w *Async) Close() error { + var err error + w.closeOnce.Do(func() { + w.closed.Store(true) + close(w.done) + err = w.Flush() + w.flusherClosed.Store(true) + close(w.flusher) + w.systemGroup.Wait() + if closer, ok := w.handler.(io.Closer); ok { + err = errors.Join(err, closer.Close()) + } + }) + return err +} + +func (w *Async) Flush() error { + if w.flusherClosed.Load() { + return fmt.Errorf("cannot write: %w", ErrAlreadyClosed) + } + done := make(chan error) + w.flusher <- done + // wait until the flusher is done + err := <-done + close(done) + return err +} + +func (w *Async) flush() error { + for { + // keep reading from the channel until it's empty + select { + case data := <-w.data: + _, err := w.handler.Write(data) + if err != nil { + return err + } + default: + return nil + } + } +} + +// Write asynchronously to the handler +func (w *Async) Write(data []byte) (n int, err error) { + if w.closed.Load() { + return 0, fmt.Errorf("cannot write: %w", ErrAlreadyClosed) + } + + buffer := make([]byte, len(data)) + n = copy(buffer, data) + w.data <- buffer + return n, nil +} diff --git a/util/write/async_option.go b/util/write/async_option.go new file mode 100644 index 00000000..9bc89692 --- /dev/null +++ b/util/write/async_option.go @@ -0,0 +1,12 @@ +package write + +import ( + "time" +) + +type AsyncOption func(writer *Async) + +// WithWriteInterval sets the interval at which writes happen at least when data is pending +func WithWriteInterval(duration time.Duration) AsyncOption { + return func(writer *Async) { writer.interval = duration } +} diff --git a/util/write/async_test.go b/util/write/async_test.go new file mode 100644 index 00000000..31090a6f --- /dev/null +++ b/util/write/async_test.go @@ -0,0 +1,110 @@ +package write + +import ( + "bytes" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAsyncWriter(t *testing.T) { + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + n, err := w.Write([]byte("hello world")) + require.NoError(t, err) + assert.Equal(t, 11, n) + + err = w.Close() + require.NoError(t, err) + + assert.Equal(t, "hello world", buffer.String()) +} + +func TestAsyncWriterFlushAfterClose(t *testing.T) { + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + n, err := w.Write([]byte("hello world")) + require.NoError(t, err) + assert.Equal(t, 11, n) + + require.NoError(t, w.Flush()) + require.NoError(t, w.Close()) + + require.ErrorIs(t, w.Flush(), ErrAlreadyClosed) + + assert.Equal(t, "hello world", buffer.String()) +} + +func TestAsyncWriteMoreThanChannelSize(t *testing.T) { + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + for range asyncWriterDataChanSize + 1 { + n, err := w.Write([]byte("aaa")) + require.NoError(t, err) + assert.Equal(t, 3, n) + } + + err := w.Close() + require.NoError(t, err) + + assert.Equal(t, strings.Repeat("aaa", asyncWriterDataChanSize+1), buffer.String()) +} + +func TestAsyncWriteInParallelAndClose(t *testing.T) { + repeat := 10 + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + wg := new(sync.WaitGroup) + for range repeat { + wg.Go(func() { + _, _ = w.Write([]byte("aa")) + }) + } + + require.NoError(t, w.Close()) + wg.Wait() +} + +func TestAsyncWriteInParallelAndWaitBeforeClosing(t *testing.T) { + repeat := 10 + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + wg := new(sync.WaitGroup) + for range repeat { + wg.Go(func() { + _, _ = w.Write([]byte("aa")) + }) + } + + wg.Wait() + require.NoError(t, w.Close()) + + assert.Equal(t, strings.Repeat("aa", repeat), buffer.String()) +} + +func TestAsyncWriteBigBuffersInParallelAndWaitBeforeClosing(t *testing.T) { + repeat := 100 + bufferSize := 1024 * 1024 + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + wg := new(sync.WaitGroup) + for range repeat { + wg.Go(func() { + buffer := make([]byte, bufferSize) + _, _ = w.Write(buffer) + }) + } + + wg.Wait() + require.NoError(t, w.Close()) + assert.Equal(t, repeat*bufferSize, buffer.Len()) +} diff --git a/util/write/file.go b/util/write/file.go new file mode 100644 index 00000000..9ccb0784 --- /dev/null +++ b/util/write/file.go @@ -0,0 +1,129 @@ +package write + +import ( + "errors" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/creativeprojects/resticprofile/platform" +) + +var ErrAttemptToWriteOnClosedFile = errors.New("cannot write to a closed or unopened file") + +type File struct { + filename string + perm os.FileMode + flag int + keepOpen bool + keepOpenTimeout time.Duration + handle *os.File + mutex sync.Mutex + timer *time.Timer + timerMutex sync.Mutex + // stats + fileOpenCount atomic.Int32 + fileCloseCount atomic.Int32 +} + +func NewFile(filename string, options ...FileOption) (f *File, err error) { + f = &File{ + filename: filename, + perm: 0644, + flag: os.O_WRONLY | os.O_APPEND | os.O_CREATE, + keepOpen: !platform.IsWindows(), + keepOpenTimeout: 10 * time.Millisecond, + } + + for _, option := range options { + option(f) + } + + err = f.open() + if !f.keepOpen { + defer func() { + err = errors.Join(err, f.Close()) + }() + } + return +} + +func (f *File) open() error { + f.mutex.Lock() + defer f.mutex.Unlock() + + if f.handle != nil { + return nil + } + var err error + f.fileOpenCount.Add(1) + f.handle, err = os.OpenFile(f.filename, f.flag, f.perm) + return err +} + +func (f *File) Close() error { + f.mutex.Lock() + defer f.mutex.Unlock() + + if f.handle != nil { + f.fileCloseCount.Add(1) + err := f.handle.Close() + f.handle = nil + return err + } + return nil +} + +func (f *File) Flush() error { + f.mutex.Lock() + defer f.mutex.Unlock() + + if f.handle != nil { + return f.handle.Sync() + } + return nil +} + +func (f *File) Write(data []byte) (n int, err error) { + if !f.keepOpen { + f.stopCloseTimer() + err := f.open() + if err != nil { + return 0, err + } + defer f.resetCloseTimer() + } + + if f.handle == nil { + return 0, ErrAttemptToWriteOnClosedFile + } + + n, err = f.handle.Write(data) + return +} + +func (f *File) stopCloseTimer() { + f.timerMutex.Lock() + defer f.timerMutex.Unlock() + if f.timer != nil { + f.timer.Stop() + } +} + +func (f *File) resetCloseTimer() { + f.timerMutex.Lock() + defer f.timerMutex.Unlock() + + if f.timer != nil { + f.timer.Stop() + } + f.timer = time.AfterFunc(f.keepOpenTimeout, func() { + _ = f.Close() + }) +} + +// stats returns the number of times the file was opened and closed +func (f *File) stats() (int32, int32) { + return f.fileOpenCount.Load(), f.fileCloseCount.Load() +} diff --git a/util/write/file_option.go b/util/write/file_option.go new file mode 100644 index 00000000..e59078ed --- /dev/null +++ b/util/write/file_option.go @@ -0,0 +1,34 @@ +package write + +import ( + "os" + "time" +) + +type FileOption func(f *File) + +// WithFileKeepOpen toggles whether the file is kept open between writes. Defaults to true for all OS except Windows. +func WithFileKeepOpen(keepOpen bool) FileOption { + return func(f *File) { f.keepOpen = keepOpen } +} + +// WithFileKeepOpenTimeout will automatically close the file after there was no more write during `timeout`. +// It defaults to 10ms. +func WithFileKeepOpenTimeout(timeout time.Duration) FileOption { + return func(f *File) { f.keepOpenTimeout = timeout } +} + +// WithFilePerm sets file perms to apply when creating the file +func WithFilePerm(perm os.FileMode) FileOption { + return func(f *File) { f.perm = perm } +} + +// WithFileFlag sets file open flags +func WithFileFlag(flag int) FileOption { + return func(f *File) { f.flag = flag } +} + +// WithFileTruncate enables that existing files are truncated +func WithFileTruncate() FileOption { + return func(f *File) { f.flag |= os.O_TRUNC } +} diff --git a/util/write/file_test.go b/util/write/file_test.go new file mode 100644 index 00000000..646de17f --- /dev/null +++ b/util/write/file_test.go @@ -0,0 +1,140 @@ +package write + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileDefaultOption(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename, WithFileKeepOpen(true), WithFileKeepOpenTimeout(1*time.Millisecond)) + require.NoError(t, err) + + n, err := w.Write([]byte("hello world")) + assert.NoError(t, err) + assert.Equal(t, 11, n) + time.Sleep(2 * time.Millisecond) + + err = w.Close() + assert.NoError(t, err) + + opened, closed := w.stats() + assert.Equal(t, int32(1), opened) + assert.Equal(t, int32(1), closed) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) +} + +func TestFileCloseAfterWrite(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename, WithFileKeepOpen(false), WithFileKeepOpenTimeout(1*time.Millisecond)) + require.NoError(t, err) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + time.Sleep(2 * time.Millisecond) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello", string(content)) + + n, err = w.Write([]byte(" world")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + time.Sleep(2 * time.Millisecond) + + content, err = os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) + + err = w.Close() + assert.NoError(t, err) + + opened, closed := w.stats() + assert.Equal(t, int32(3), opened) // 1 during instantiation + 1 for each write + assert.Equal(t, int32(3), closed) // 1 during instantiation + 1 for each write +} + +func TestFileNoTimeToCloseAfterWrite(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename, WithFileKeepOpen(false), WithFileKeepOpenTimeout(1*time.Second)) + require.NoError(t, err) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello", string(content)) + + n, err = w.Write([]byte(" world")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + + content, err = os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) + + err = w.Close() + assert.NoError(t, err) + + opened, closed := w.stats() + assert.Equal(t, int32(2), opened) // 1 during instantiation + 1 for both write + assert.Equal(t, int32(2), closed) // 1 during instantiation + 1 for both write +} + +func TestFileCanFlush(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename) + require.NoError(t, err) + assert.NoError(t, w.Flush()) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.NoError(t, w.Flush()) + + n, err = w.Write([]byte(" world")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + assert.NoError(t, w.Flush()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) + + assert.NoError(t, w.Flush()) + err = w.Close() + assert.NoError(t, err) + + assert.NoError(t, w.Flush()) +} + +func TestFileWriteOnClosedFile(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename, WithFileKeepOpen(true)) + require.NoError(t, err) + assert.NoError(t, w.Close()) + + _, err = w.Write([]byte("aaa")) + assert.ErrorIs(t, err, ErrAttemptToWriteOnClosedFile) +} diff --git a/util/write/mock_write_closer_test.go b/util/write/mock_write_closer_test.go new file mode 100644 index 00000000..0c5e1e83 --- /dev/null +++ b/util/write/mock_write_closer_test.go @@ -0,0 +1,25 @@ +package write + +import "io" + +type mockWriteCloser struct { + writer func(data []byte) (int, error) + closer func() error +} + +func newMockWriteCloser(writer func(data []byte) (int, error), closer func() error) mockWriteCloser { + return mockWriteCloser{ + writer: writer, + closer: closer, + } +} + +func (m mockWriteCloser) Write(data []byte) (int, error) { + return m.writer(data) +} + +func (m mockWriteCloser) Close() error { + return m.closer() +} + +var _ io.WriteCloser = &Append{} diff --git a/util/writer.go b/util/writer.go new file mode 100644 index 00000000..58a02283 --- /dev/null +++ b/util/writer.go @@ -0,0 +1,73 @@ +package util + +import ( + "io" + "sync" +) + +// Flusher allows a Writer to declare it may buffer content that can be flushed +type Flusher interface { + // Flush writes any pending bytes to output + Flush() error +} + +// FlushWriter attempts to flush a writer if it implements Flusher +func FlushWriter(writer io.Writer) (flushable bool, err error) { + var f Flusher + if f, flushable = writer.(Flusher); flushable { + err = f.Flush() + } + return +} + +// NewSyncWriter creates a new writer that is safe for concurrent use +func NewSyncWriter[W io.Writer](writer W) SyncWriter[W] { + return NewSyncWriterMutex(writer, new(sync.Mutex)) +} + +// NewSyncWriterMutex creates a new writer that is safe for concurrent use and synced with the specified sync.Mutex +func NewSyncWriterMutex[W io.Writer](writer W, mutex *sync.Mutex) SyncWriter[W] { + return &syncWriter[W]{writer: writer, mutex: mutex} +} + +// SyncWriter implements an io.Writer that is safe for concurrent use +type SyncWriter[W io.Writer] interface { + io.Writer + // Locked provides sync.Mutex locked access to the underlying writer + Locked(fn func(W) error) error +} + +type syncWriter[W io.Writer] struct { + writer io.Writer + mutex *sync.Mutex +} + +func (w *syncWriter[W]) Locked(fn func(writer W) error) error { + w.mutex.Lock() + defer w.mutex.Unlock() + return fn(w.writer.(W)) +} + +func (w *syncWriter[W]) Write(p []byte) (n int, err error) { + w.mutex.Lock() + defer w.mutex.Unlock() + return w.writer.Write(p) +} + +func (w *syncWriter[W]) Flush() (err error) { + w.mutex.Lock() + defer w.mutex.Unlock() + if f, ok := w.writer.(Flusher); ok { + err = f.Flush() + } + return +} + +func (w *syncWriter[W]) Close() (err error) { + w.mutex.Lock() + defer w.mutex.Unlock() + if f, ok := w.writer.(io.Closer); ok { + err = f.Close() + } + return +} diff --git a/util/writer_test.go b/util/writer_test.go new file mode 100644 index 00000000..790c7252 --- /dev/null +++ b/util/writer_test.go @@ -0,0 +1,207 @@ +package util + +import ( + "bytes" + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// trackingWriteCloser is a bytes.Buffer that also tracks Close and Flush calls. +type trackingWriteCloser struct { + bytes.Buffer + closed bool + flushed bool + closeErr error + flushErr error +} + +func (t *trackingWriteCloser) Close() error { + t.closed = true + return t.closeErr +} + +func (t *trackingWriteCloser) Flush() error { + t.flushed = true + return t.flushErr +} + +// --- FlushWriter tests --- + +func TestFlushWriterFlushable(t *testing.T) { + tw := &trackingWriteCloser{} + flushable, err := FlushWriter(tw) + assert.True(t, flushable) + assert.NoError(t, err) + assert.True(t, tw.flushed) +} + +func TestFlushWriterNotFlushable(t *testing.T) { + var buf bytes.Buffer + flushable, err := FlushWriter(&buf) + assert.False(t, flushable) + assert.NoError(t, err) +} + +func TestFlushWriterFlushError(t *testing.T) { + sentinel := errors.New("flush error") + tw := &trackingWriteCloser{flushErr: sentinel} + flushable, err := FlushWriter(tw) + assert.True(t, flushable) + assert.Equal(t, sentinel, err) +} + +// --- syncWriter tests --- + +func TestSyncWriterWrite(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "hello", buf.String()) +} + +func TestSyncWriterMultipleWrites(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + for _, s := range []string{"foo", "bar", "baz"} { + _, err := w.Write([]byte(s)) + require.NoError(t, err) + } + assert.Equal(t, "foobarbaz", buf.String()) +} + +func TestSyncWriterLocked(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + err := w.Locked(func(inner *bytes.Buffer) error { + _, err := inner.WriteString("locked") + return err + }) + assert.NoError(t, err) + assert.Equal(t, "locked", buf.String()) +} + +func TestSyncWriterLockedPropagatesError(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + sentinel := errors.New("locked error") + err := w.Locked(func(_ *bytes.Buffer) error { return sentinel }) + assert.Equal(t, sentinel, err) +} + +func TestSyncWriterFlushWithFlusher(t *testing.T) { + tw := &trackingWriteCloser{} + w := NewSyncWriter(tw) + + // syncWriter exposes Flush via a type assertion at call site + f, ok := w.(Flusher) + require.True(t, ok) + err := f.Flush() + assert.NoError(t, err) + assert.True(t, tw.flushed) +} + +func TestSyncWriterFlushError(t *testing.T) { + sentinel := errors.New("flush error") + tw := &trackingWriteCloser{flushErr: sentinel} + w := NewSyncWriter(tw) + + err := w.(Flusher).Flush() + assert.Equal(t, sentinel, err) +} + +func TestSyncWriterFlushNonFlusher(t *testing.T) { + // bytes.Buffer is not a Flusher; Flush should be a no-op + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + err := w.(Flusher).Flush() + assert.NoError(t, err) +} + +func TestSyncWriterCloseWithCloser(t *testing.T) { + tw := &trackingWriteCloser{} + w := NewSyncWriter(tw) + + type closer interface{ Close() error } + err := w.(closer).Close() + assert.NoError(t, err) + assert.True(t, tw.closed) +} + +func TestSyncWriterCloseError(t *testing.T) { + sentinel := errors.New("close error") + tw := &trackingWriteCloser{closeErr: sentinel} + w := NewSyncWriter(tw) + + type closer interface{ Close() error } + err := w.(closer).Close() + assert.Equal(t, sentinel, err) +} + +func TestSyncWriterCloseNonCloser(t *testing.T) { + // bytes.Buffer is not an io.Closer; Close should be a no-op + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + type closer interface{ Close() error } + err := w.(closer).Close() + assert.NoError(t, err) +} + +func TestSyncWriterMutex(t *testing.T) { + var buf bytes.Buffer + mutex := new(sync.Mutex) + w := NewSyncWriterMutex[*bytes.Buffer](&buf, mutex) + + n, err := w.Write([]byte("shared")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + assert.Equal(t, "shared", buf.String()) +} + +func TestSyncWriterConcurrentWrites(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter[*bytes.Buffer](&buf) + + var wg sync.WaitGroup + for range 50 { + wg.Go(func() { + _, _ = w.Write([]byte("x")) + }) + } + wg.Wait() + assert.Equal(t, 50, buf.Len()) +} + +func TestSyncWriterConcurrentLockedAndWrite(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter[*bytes.Buffer](&buf) + + var wg sync.WaitGroup + for range 10 { + wg.Add(2) + go func() { + defer wg.Done() + _, _ = w.Write([]byte("w")) + }() + go func() { + defer wg.Done() + _ = w.Locked(func(b *bytes.Buffer) error { + _, err := b.WriteString("l") + return err + }) + }() + } + wg.Wait() + assert.Equal(t, 20, buf.Len()) +} diff --git a/wrapper.go b/wrapper.go index be40f235..7364bbfc 100644 --- a/wrapper.go +++ b/wrapper.go @@ -22,7 +22,6 @@ import ( "github.com/creativeprojects/resticprofile/monitor/hook" "github.com/creativeprojects/resticprofile/restic" "github.com/creativeprojects/resticprofile/shell" - "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util" "github.com/creativeprojects/resticprofile/util/collect" ) @@ -438,8 +437,8 @@ func (r *resticWrapper) prepareCommand(command string, args *shell.Args, allowEx rCommand := newShellCommand(binary, arguments, env, r.getShell(), r.dryRun, r.sigChan, r.setPID) rCommand.publicArgs = publicArguments // stdout are stderr are coming from the default terminal (in case they're redirected) - rCommand.stdout = term.GetOutput() - rCommand.stderr = term.GetErrorOutput() + rCommand.stdout = r.ctx.terminal.Stdout() + rCommand.stderr = r.ctx.terminal.Stderr() rCommand.streamError = r.profile.StreamError rCommand.dir = dir @@ -545,7 +544,7 @@ func (r *resticWrapper) runCommand(command string) error { if len(r.progress) > 0 { if r.profile.Backup.ExtendedStatus { rCommand.scanOutput = shell.ScanBackupJson - } else if !term.OsStdoutIsTerminal() { + } else if !r.ctx.terminal.StdoutIsTerminal() { // restic detects its output is not a terminal and no longer displays the monitor. // Scan plain output only if resticprofile is not run from a terminal (e.g. schedule) rCommand.scanOutput = shell.ScanBackupPlain @@ -629,9 +628,9 @@ func (r *resticWrapper) runShellCommands(commands []string, commandsType, comman // creating command rCommand := newShellCommand(shellCommand, nil, env, r.getShell(), r.dryRun, r.sigChan, r.setPID) // stdout are stderr are coming from the default terminal (in case they're redirected) - rCommand.stdout = term.GetOutput() - rCommand.stderr = term.GetErrorOutput() - term.FlushAllOutput() + rCommand.stdout = r.ctx.terminal.Stdout() + rCommand.stderr = r.ctx.terminal.Stderr() + r.ctx.terminal.FlushAllOutput() _, stderr, err := runShellCommand(rCommand) if err != nil { err = fmt.Errorf("%s on profile '%s': %w", commandsType, r.profile.Name, err) @@ -659,9 +658,9 @@ func (r *resticWrapper) runFinalShellCommands(command string, fail error) { // creating command rCommand := newShellCommand(cmd, nil, env, r.getShell(), r.dryRun, r.sigChan, r.setPID) // stdout are stderr are coming from the default terminal (in case they're redirected) - rCommand.stdout = term.GetOutput() - rCommand.stderr = term.GetErrorOutput() - term.FlushAllOutput() + rCommand.stdout = r.ctx.terminal.Stdout() + rCommand.stderr = r.ctx.terminal.Stderr() + r.ctx.terminal.FlushAllOutput() _, _, err := runShellCommand(rCommand) if err != nil { clog.Errorf("run-finally command %d/%d failed ('%s' on profile '%s'): %s", @@ -694,7 +693,7 @@ func (r *resticWrapper) sendFinally(monitoring config.SendMonitoringSections, co func (r *resticWrapper) sendMonitoring(sections []config.SendMonitoringSection, command, sendType string, err error) { for i, section := range sections { clog.Debugf("starting %q from %s %d/%d", sendType, command, i+1, len(sections)) - term.FlushAllOutput() + r.ctx.terminal.FlushAllOutput() err := r.sender.Send(section, r.getContextWithError(err), r.profile.GetEnvironment(true)) if err != nil { clog.Warningf("%q returned an error: %s", sendType, err.Error()) diff --git a/wrapper_streamsource.go b/wrapper_streamsource.go index 205abad4..8b4c027a 100644 --- a/wrapper_streamsource.go +++ b/wrapper_streamsource.go @@ -14,7 +14,6 @@ import ( "time" "github.com/creativeprojects/clog" - "github.com/creativeprojects/resticprofile/term" ) func (r *resticWrapper) prepareStreamSource() (io.ReadCloser, error) { @@ -77,7 +76,7 @@ func (r *resticWrapper) prepareCommandStreamSource() (io.ReadCloser, error) { clog.Debugf("starting 'stdin-command' command %d/%d: %s", i+1, len(r.profile.Backup.StdinCommand), sourceCommand) rCommand := newShellCommand(sourceCommand, nil, env, r.getShell(), r.dryRun, commandSignals, nil) rCommand.stdout = bufferedWriter - rCommand.stderr = term.GetErrorOutput() + rCommand.stderr = r.ctx.terminal.Stderr() _, stderr, err := runShellCommand(rCommand) if err != nil { diff --git a/wrapper_test.go b/wrapper_test.go index fd4a8793..ad60a36a 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -182,10 +182,11 @@ func TestFilteredArgumentsRegression(t *testing.T) { profile, err := cfg.GetProfile("default") require.NoError(t, err) wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - binary: "restic", - profile: profile, - command: "test", + flags: commandLineFlags{dryRun: true}, + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), }) for command, commandline := range test.expected { @@ -203,9 +204,10 @@ func TestGetEmptyEnvironment(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "restic", - profile: profile, - command: "test", + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) env := wrapper.getEnvironment(false) @@ -221,9 +223,10 @@ func TestGetSingleEnvironment(t *testing.T) { } profile.ResolveConfiguration() ctx := &Context{ - binary: "restic", - profile: profile, - command: "test", + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) env := wrapper.getEnvironment(false) @@ -240,9 +243,10 @@ func TestGetMultipleEnvironment(t *testing.T) { } profile.ResolveConfiguration() ctx := &Context{ - binary: "restic", - profile: profile, - command: "test", + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) @@ -267,9 +271,10 @@ func TestPreProfileScriptFail(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunBefore = []string{"exit 1"} // this should both work on unix shell and windows batch ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -282,9 +287,10 @@ func TestPostProfileScriptFail(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunAfter = []string{"exit 1"} // this should both work on unix shell and windows batch ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -296,9 +302,10 @@ func TestRunEchoProfile(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -313,9 +320,10 @@ func TestPostProfileAfterFail(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunAfter = []string{"echo failed > " + testFile} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -332,9 +340,10 @@ func TestPostFailProfile(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo failed > " + testFile} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -368,9 +377,10 @@ func TestFinallyProfile(t *testing.T) { t.Run("backup-before-profile", func(t *testing.T) { newProfile() ctx := &Context{ - binary: "echo", - profile: profile, - command: "backup", + binary: "echo", + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -382,9 +392,10 @@ func TestFinallyProfile(t *testing.T) { newProfile() profile.RunFinally = nil ctx := &Context{ - binary: "echo", - profile: profile, - command: "backup", + binary: "echo", + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -395,9 +406,10 @@ func TestFinallyProfile(t *testing.T) { t.Run("on-error", func(t *testing.T) { newProfile() ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -407,12 +419,12 @@ func TestFinallyProfile(t *testing.T) { } func Example_runProfile() { - term.SetOutput(os.Stdout) profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -423,13 +435,14 @@ func Example_runProfile() { } func TestRunRedirectOutputOfEchoProfile(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -438,14 +451,15 @@ func TestRunRedirectOutputOfEchoProfile(t *testing.T) { } func TestDryRun(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - binary: "echo", - profile: profile, - command: "test", + flags: commandLineFlags{dryRun: true}, + binary: "echo", + profile: profile, + command: "test", + terminal: terminal, }) err := wrapper.runProfile() assert.NoError(t, err) @@ -453,15 +467,16 @@ func TestDryRun(t *testing.T) { } func TestEnvProfileName(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "TestEnvProfileName") profile.RunBefore = []string{"echo profile name = $PROFILE_NAME"} ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -470,15 +485,16 @@ func TestEnvProfileName(t *testing.T) { } func TestEnvProfileCommand(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunBefore = []string{"echo profile command = $PROFILE_COMMAND"} ctx := &Context{ - binary: "echo", - profile: profile, - command: "test-command", + binary: "echo", + profile: profile, + command: "test-command", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -487,15 +503,16 @@ func TestEnvProfileCommand(t *testing.T) { } func TestEnvError(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo error: $ERROR_MESSAGE"} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -504,15 +521,16 @@ func TestEnvError(t *testing.T) { } func TestEnvErrorCommandLine(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo cmd: $ERROR_COMMANDLINE"} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -521,15 +539,16 @@ func TestEnvErrorCommandLine(t *testing.T) { } func TestEnvErrorExitCode(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo exit-code: $ERROR_EXIT_CODE"} ctx := &Context{ - binary: "exit", - profile: profile, - command: "5", + binary: "exit", + profile: profile, + command: "5", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -538,16 +557,17 @@ func TestEnvErrorExitCode(t *testing.T) { } func TestEnvStderr(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo stderr: $ERROR_STDERR"} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "command", - request: Request{arguments: []string{"--stderr", "error_message", "--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "command", + request: Request{arguments: []string{"--stderr", "error_message", "--exit", "1"}}, + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -562,9 +582,10 @@ func TestRunProfileWithSetPIDCallback(t *testing.T) { profile.Lock = filepath.Join(os.TempDir(), fmt.Sprintf("%s%d%d.tmp", "TestRunProfileWithSetPIDCallback", time.Now().UnixNano(), os.Getpid())) t.Logf("lockfile = %s", profile.Lock) ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -576,9 +597,10 @@ func TestInitializeNoError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitialize() @@ -590,10 +612,11 @@ func TestInitializeWithError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitialize() @@ -606,9 +629,10 @@ func TestInitializeCopyNoError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Copy = &config.CopySection{InitializeCopyChunkerParams: maybe.False()} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitializeCopy() @@ -621,10 +645,11 @@ func TestInitializeCopyWithError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Copy = &config.CopySection{InitializeCopyChunkerParams: maybe.False()} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitializeCopy() @@ -636,9 +661,10 @@ func TestCheckNoError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCheck() @@ -650,10 +676,11 @@ func TestCheckWithError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCheck() @@ -665,9 +692,10 @@ func TestRetentionNoError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runRetention() @@ -679,10 +707,11 @@ func TestRetentionWithError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runRetention() @@ -711,10 +740,11 @@ func TestBackupWithStreamSource(t *testing.T) { profile.Backup = &config.BackupSection{} signals := make(chan os.Signal, 1) ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "stdin-test", - sigChan: signals, + binary: mockBinary, + profile: profile, + command: "stdin-test", + sigChan: signals, + terminal: term.NewTerminal(), } wrapper = newResticWrapper(ctx) return @@ -854,9 +884,10 @@ func TestBackupWithSuccess(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -869,10 +900,11 @@ func TestBackupWithError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "1"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -899,12 +931,13 @@ func TestBackupWithResticLockFailureRetried(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - global: global, - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, - sigChan: sigChan, + global: global, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, + sigChan: sigChan, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.lockWait = &lockWait @@ -935,12 +968,13 @@ func TestBackupWithResticLockFailureCancelled(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - global: global, - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, - sigChan: sigChan, + global: global, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, + sigChan: sigChan, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.lockWait = &lockWait @@ -960,10 +994,11 @@ func TestBackupWithNoConfiguration(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "1"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -976,10 +1011,11 @@ func TestBackupWithNoConfigurationButStatusFile(t *testing.T) { profile := config.NewProfile(nil, "name") profile.StatusFile = "status.json" ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "1"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.addProgress(status.NewProgress(profile, status.NewStatus("status.json"))) @@ -993,10 +1029,11 @@ func TestBackupWithWarningAsError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "3"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "3"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -1009,10 +1046,11 @@ func TestBackupWithSupressedWarnings(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Backup = &config.BackupSection{NoErrorOnWarning: true} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "3"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "3"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -1043,9 +1081,10 @@ func TestRunShellCommands(t *testing.T) { t.Run(fmt.Sprintf("run-before '%s'", command), func(t *testing.T) { section.RunBefore = []string{"exit 2"} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: command, + binary: mockBinary, + profile: profile, + command: command, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -1061,9 +1100,10 @@ func TestRunShellCommands(t *testing.T) { t.Run(fmt.Sprintf("run-after '%s'", command), func(t *testing.T) { section.RunAfter = []string{"exit 2"} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: command, + binary: mockBinary, + profile: profile, + command: command, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -1079,8 +1119,8 @@ func TestRunShellCommands(t *testing.T) { } func TestRunStreamErrorHandler(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) errorCommand := `echo "detected error in $PROFILE_COMMAND"` @@ -1088,10 +1128,11 @@ func TestRunStreamErrorHandler(t *testing.T) { profile.Backup = &config.BackupSection{} profile.StreamError = []config.StreamErrorSection{{Pattern: ".+error-line.+", Run: errorCommand}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", - request: Request{arguments: []string{"--stderr", "--error-line--"}}, + binary: mockBinary, + profile: profile, + command: "backup", + request: Request{arguments: []string{"--stderr", "--error-line--"}}, + terminal: terminal, } wrapper := newResticWrapper(ctx) @@ -1107,10 +1148,11 @@ func TestRunStreamErrorHandlerDoesNotBreakCommand(t *testing.T) { profile.Backup = &config.BackupSection{} profile.StreamError = []config.StreamErrorSection{{Pattern: ".+error-line.+", Run: "exit 1"}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", - request: Request{arguments: []string{"--stderr", "--error-line--"}}, + binary: mockBinary, + profile: profile, + command: "backup", + request: Request{arguments: []string{"--stderr", "--error-line--"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) @@ -1125,10 +1167,11 @@ func TestStreamErrorHandlerWithInvalidRegex(t *testing.T) { profile.Backup = &config.BackupSection{} profile.StreamError = []config.StreamErrorSection{{Pattern: "(", Run: "echo pass"}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", - request: Request{arguments: []string{}}, + binary: mockBinary, + profile: profile, + command: "backup", + request: Request{arguments: []string{}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) @@ -1141,9 +1184,10 @@ func TestCanRetryAfterErrorDontFailWhenNoOutputAnalysis(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", + binary: mockBinary, + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) summary := monitor.Summary{} @@ -1164,9 +1208,10 @@ func TestCanRetryAfterRemoteStaleLockFailure(t *testing.T) { profile.Repository = config.NewConfidentialValue("my-repo") profile.ForceLock = true ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", + binary: mockBinary, + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.startTime = time.Now() @@ -1229,9 +1274,10 @@ func TestCanRetryAfterRemoteLockFailure(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Repository = config.NewConfidentialValue("my-repo") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", + binary: mockBinary, + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.startTime = time.Now() @@ -1284,10 +1330,11 @@ func TestCanUseResticLockRetry(t *testing.T) { getWrapper := func() *resticWrapper { wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - binary: "restic", - profile: profile, - command: constants.CommandBackup, + flags: commandLineFlags{dryRun: true}, + binary: "restic", + profile: profile, + command: constants.CommandBackup, + terminal: term.NewTerminal(), }) wrapper.startTime = time.Now() wrapper.global.ResticLockRetryAfter = 1 * time.Minute @@ -1370,23 +1417,26 @@ func TestLocksAndLockWait(t *testing.T) { profile.Lock = filepath.Join(os.TempDir(), fmt.Sprintf("%s%d%d.tmp", "TestLockWait", time.Now().UnixNano(), os.Getpid())) defer os.Remove(profile.Lock) - term.SetOutput(os.Stdout) + terminal := term.NewTerminal() ctx1 := &Context{ - binary: mockBinary, - profile: profile, - command: constants.CommandBackup, - request: Request{arguments: []string{"--sleep", "1500"}}, + binary: mockBinary, + profile: profile, + command: constants.CommandBackup, + request: Request{arguments: []string{"--sleep", "1500"}}, + terminal: terminal, } ctx2 := &Context{ - binary: mockBinary, - profile: profile, - command: constants.CommandBackup, + binary: mockBinary, + profile: profile, + command: constants.CommandBackup, + terminal: terminal, } ctx3 := &Context{ - binary: mockBinary, - profile: profile, - command: constants.CommandBackup, + binary: mockBinary, + profile: profile, + command: constants.CommandBackup, + terminal: terminal, } w1 := newResticWrapper(ctx1) w2 := newResticWrapper(ctx2) @@ -1675,11 +1725,12 @@ func TestRunInitCopyCommand(t *testing.T) { defer clog.SetDefaultLogger(defaultLogger) wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - global: config.NewGlobal(), - binary: "test", - profile: testCase.profile, - command: "copy", + flags: commandLineFlags{dryRun: true}, + global: config.NewGlobal(), + binary: "test", + profile: testCase.profile, + command: "copy", + terminal: term.NewTerminal(), }) // 1. run init command with copy profile err := wrapper.runInitializeCopy() @@ -1703,9 +1754,10 @@ func TestCopyNoSnapshot(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Copy = &config.CopySection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := shell.NewArgs() @@ -1719,9 +1771,10 @@ func TestCopySnapshot(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Copy = &config.CopySection{Snapshots: []string{"snapshot1", "snapshot2"}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := shell.NewArgs() @@ -1737,9 +1790,10 @@ func TestPrepareCommandShouldEscapeBinary(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") ctx := &Context{ - binary: "/full path to/restic", - profile: profile, - command: "backup", + binary: "/full path to/restic", + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := shell.NewArgs() @@ -1752,10 +1806,11 @@ func TestRunUnlockWithCommandLineFlags(t *testing.T) { profile := config.NewProfile(&config.Config{}, "TestProfile") ctx := &Context{ - binary: "restic", - profile: profile, - command: constants.CommandForget, - request: Request{arguments: []string{"some-string", "--some-flag", "-n"}}, + binary: "restic", + profile: profile, + command: constants.CommandForget, + request: Request{arguments: []string{"some-string", "--some-flag", "-n"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := profile.GetCommandFlags(constants.CommandUnlock)