From 343f355f3075c6ff0f37005a51b67868d2af083d Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 14:38:28 +0100 Subject: [PATCH 01/15] feat: add terminal instance instead of global term package --- commands.go | 41 ++-- commands_display.go | 206 +++++------------- commands_display_test.go | 281 +++--------------------- commands_generate.go | 14 +- commands_schedule.go | 9 +- commands_schedule_test.go | 46 ++-- commands_test.go | 66 +++--- config/flag.go | 5 +- config/flag_test.go | 4 +- config/info.go | 3 +- context.go | 12 +- integration_test.go | 17 +- logger.go | 147 +------------ logger_test.go | 4 +- main.go | 31 ++- own_commands.go | 20 +- own_commands_test.go | 14 +- schedule/handler_crond.go | 3 +- schedule/handler_darwin.go | 2 +- schedule/handler_systemd.go | 12 +- schedule/handler_systemd_test.go | 7 +- schedule/handler_windows.go | 3 +- schedule/schedules.go | 30 +-- schedule/schedules_test.go | 22 +- serve.go | 5 +- shell/util.go | 45 ++++ term/default.go | 17 ++ term/nil_reader.go | 7 + term/recorder.go | 39 ++++ term/recording.go | 75 +++++++ term/term.go | 360 +++++++++++++++++++++++-------- term/term_test.go | 107 --------- term/terminal.go | 193 +++++++++++++++++ term/terminal_option.go | 62 ++++++ term/terminal_test.go | 190 ++++++++++++++++ term/test/main.go | 16 ++ update.go | 11 +- update_test.go | 3 +- util/ansi/colors.go | 39 ++++ util/ansi/colors_test.go | 19 ++ util/ansi/escape.go | 36 ++++ util/ansi/escape_test.go | 14 ++ util/ansi/linewriter.go | 110 ++++++++++ util/ansi/linewriter_test.go | 280 ++++++++++++++++++++++++ util/ansi/runes.go | 26 +++ util/ansi/runes_test.go | 40 ++++ util/filewriter.go | 229 ++++++++++++++++++++ util/filewriter_test.go | 258 ++++++++++++++++++++++ util/reader.go | 94 ++++++++ util/reader_test.go | 293 +++++++++++++++++++++++++ util/writer.go | 73 +++++++ util/writer_test.go | 207 ++++++++++++++++++ wrapper_test.go | 55 +++-- 53 files changed, 2950 insertions(+), 952 deletions(-) create mode 100644 shell/util.go create mode 100644 term/default.go create mode 100644 term/nil_reader.go create mode 100644 term/recorder.go create mode 100644 term/recording.go delete mode 100644 term/term_test.go create mode 100644 term/terminal.go create mode 100644 term/terminal_option.go create mode 100644 term/terminal_test.go create mode 100644 term/test/main.go create mode 100644 util/ansi/colors.go create mode 100644 util/ansi/colors_test.go create mode 100644 util/ansi/escape.go create mode 100644 util/ansi/escape_test.go create mode 100644 util/ansi/linewriter.go create mode 100644 util/ansi/linewriter_test.go create mode 100644 util/ansi/runes.go create mode 100644 util/ansi/runes_test.go create mode 100644 util/filewriter.go create mode 100644 util/filewriter_test.go create mode 100644 util/reader.go create mode 100644 util/reader_test.go create mode 100644 util/writer.go create mode 100644 util/writer_test.go diff --git a/commands.go b/commands.go index 9b101381..5e638b20 100644 --- a/commands.go +++ b/commands.go @@ -22,7 +22,6 @@ import ( "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/platform" "github.com/creativeprojects/resticprofile/remote" - "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/win" "github.com/distatus/battery" ) @@ -190,11 +189,11 @@ func getOwnCommands() []ownCommand { } } -func panicCommand(_ io.Writer, _ commandContext) error { +func panicCommand(_ commandContext) error { panic("you asked for it") } -func completeCommand(output io.Writer, ctx commandContext) error { +func completeCommand(ctx commandContext) error { args := ctx.request.arguments requester := "unknown" requesterVersion := 0 @@ -226,13 +225,13 @@ func completeCommand(output io.Writer, ctx commandContext) error { completions := NewCompleter(ctx.ownCommands.All(), DefaultFlagsLoader, includeDescription).Complete(args) if len(completions) > 0 { for _, completion := range completions { - fmt.Fprintln(output, completion) + ctx.terminal.Println(completion) } } return nil } -func showProfileOrGroup(output io.Writer, ctx commandContext) error { +func showProfileOrGroup(ctx commandContext) error { c := ctx.config flags := ctx.flags @@ -270,21 +269,21 @@ func showProfileOrGroup(output io.Writer, ctx commandContext) error { } // Show global - err = config.ShowStruct(output, global, constants.SectionConfigurationGlobal) + err = config.ShowStruct(ctx.terminal, global, constants.SectionConfigurationGlobal) if err != nil { clog.Errorf("cannot show global section: %s", err.Error()) } - _, _ = fmt.Fprintln(output) + _, _ = ctx.terminal.Println() // Show profile or group - err = config.ShowStruct(output, profileOrGroup, profileOrGroup.Kind()+" "+flags.name) + err = config.ShowStruct(ctx.terminal, profileOrGroup, profileOrGroup.Kind()+" "+flags.name) if err != nil { clog.Errorf("cannot show profile or group '%s': %s", flags.name, err.Error()) } - _, _ = fmt.Fprintln(output) + _, _ = ctx.terminal.Println() // Show schedules - showSchedules(output, slices.Collect(maps.Values(profileOrGroup.Schedules()))) + showSchedules(ctx.terminal, slices.Collect(maps.Values(profileOrGroup.Schedules()))) if profile, ok := profileOrGroup.(*config.Profile); ok { // Show deprecation notice @@ -309,7 +308,7 @@ func showSchedules(output io.Writer, schedules []*config.Schedule) { } // randomKey simply display a base64'd random key to the console -func randomKey(output io.Writer, ctx commandContext) error { +func randomKey(ctx commandContext) error { var err error flags := ctx.flags size := uint64(1024) @@ -329,19 +328,19 @@ func randomKey(output io.Writer, ctx commandContext) error { if err != nil { return err } - encoder := base64.NewEncoder(base64.StdEncoding, output) + encoder := base64.NewEncoder(base64.StdEncoding, ctx.terminal) _, err = encoder.Write(buffer) encoder.Close() - fmt.Fprintln(output, "") + ctx.terminal.Println() return err } -func testElevationCommand(_ io.Writer, ctx commandContext) error { +func testElevationCommand(ctx commandContext) error { if ctx.flags.isChild { client := remote.NewClient(ctx.flags.parentPort) - term.Print("first line", "\n") - term.Println("second", "one") - term.Printf("value = %d\n", 11) + ctx.terminal.Print("first line", "\n") + ctx.terminal.Println("second", "one") + ctx.terminal.Printf("value = %d\n", 11) err := client.Done() if err != nil { return err @@ -395,7 +394,7 @@ func elevated() error { return nil } -func batteryCommand(stdout io.Writer, _ commandContext) error { +func batteryCommand(ctx commandContext) error { all, err := batt.Batteries() if err != nil { if errors.Is(err, battery.ErrFatal{}) { @@ -407,13 +406,13 @@ func batteryCommand(stdout io.Writer, _ commandContext) error { clog.Info("no battery detected") return nil } - fmt.Fprintln(stdout, "") - w := tabwriter.NewWriter(stdout, 2, 2, 2, ' ', tabwriter.AlignRight) + ctx.terminal.Println() + w := tabwriter.NewWriter(ctx.terminal, 2, 2, 2, ' ', tabwriter.AlignRight) fmt.Fprintf(w, "Battery\tStatus\tCurrent capacity\tFull capacity\tDesign Capacity\tCharge Rate\tVoltage\t\n") for index, batt := range all { fmt.Fprintf(w, "#%d\t%s\t%.2f mWh\t%.2f mWh\t%.2f mWh\t%.2f mWh\t%.2f V\t\n", index, batt.State, batt.Current, batt.Full, batt.Design, batt.ChargeRate, batt.Voltage) } w.Flush() - fmt.Fprintln(stdout, "") + ctx.terminal.Println() return nil } diff --git a/commands_display.go b/commands_display.go index 5737cdf7..f11ccbc9 100644 --- a/commands_display.go +++ b/commands_display.go @@ -17,30 +17,18 @@ import ( "github.com/creativeprojects/resticprofile/shell" "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util" + "github.com/creativeprojects/resticprofile/util/ansi" "github.com/creativeprojects/resticprofile/util/collect" - "github.com/fatih/color" - colorable "github.com/mattn/go-colorable" ) -var ( - ansiBold = color.New(color.Bold).SprintFunc() - ansiCyan = color.New(color.FgCyan).SprintFunc() - ansiYellow = color.New(color.FgYellow).SprintFunc() -) - -func displayWriter(output io.Writer, flags commandLineFlags) (out func(args ...any) io.Writer, closer func()) { - if term.GetOutput() == output { - output = term.GetColorableOutput() - - if width, _ := term.OsStdoutTerminalSize(); width > 10 { - output = newLineLengthWriter(output, width) +func displayWriter(terminal *term.Terminal) (out func(args ...any) io.Writer, closer func()) { + var output io.Writer = terminal + if terminal.StdoutIsTerminal() { + if width, _ := terminal.Size(); width > 10 { + output = ansi.NewLineLengthWriter(terminal, width) } } - if flags.noAnsi { - output = colorable.NewNonColorable(output) - } - w := tabwriter.NewWriter(output, 0, 0, 2, ' ', 0) out = func(args ...any) io.Writer { @@ -70,12 +58,12 @@ func getCommonUsageHelpLine(commandName string, withProfile bool) string { } return fmt.Sprintf( "%s [resticprofile flags] %s%s", - ansiBold("resticprofile"), profile, ansiBold(commandName), + ansi.Bold("resticprofile"), profile, ansi.Bold(commandName), ) } -func displayOwnCommands(output io.Writer, ctx commandContext) { - out, closer := displayWriter(output, ctx.flags) +func displayOwnCommands(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() for _, command := range ctx.ownCommands.commands { @@ -87,8 +75,8 @@ func displayOwnCommands(output io.Writer, ctx commandContext) { } } -func displayOwnCommandHelp(output io.Writer, commandName string, ctx commandContext) { - out, closer := displayWriter(output, ctx.flags) +func displayOwnCommandHelp(ctx commandContext, commandName string) { + out, closer := displayWriter(ctx.terminal) defer closer() var command *ownCommand @@ -131,8 +119,8 @@ func displayOwnCommandHelp(output io.Writer, commandName string, ctx commandCont } } -func displayCommonUsageHelp(output io.Writer, ctx commandContext) { - out, closer := displayWriter(output, ctx.flags) +func displayCommonUsageHelp(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() out("resticprofile is a configuration profiles manager for backup profiles and ") @@ -142,26 +130,26 @@ func displayCommonUsageHelp(output io.Writer, ctx commandContext) { out("\t%s [restic flags]\n", getCommonUsageHelpLine("restic-command", true)) out("\t%s [command specific flags]\n", getCommonUsageHelpLine("resticprofile-command", true)) out("\n") - out(ansiBold("resticprofile flags:\n")) + out(ansi.Bold("resticprofile flags:\n")) out(ctx.flags.usagesHelp) out("\n\n") - out(ansiBold("resticprofile own commands:\n")) - displayOwnCommands(out(), ctx) + out(ansi.Bold("resticprofile own commands:\n")) + displayOwnCommands(ctx) out("\n") out("%s at %s\n", - ansiBold("Documentation available"), - ansiBold(ansiCyan("https://creativeprojects.github.io/resticprofile/"))) + ansi.Bold("Documentation available"), + ansi.Bold(ansi.Cyan("https://creativeprojects.github.io/resticprofile/"))) out("\n") } -func displayResticHelp(output io.Writer, configuration *config.Config, flags commandLineFlags, command string) { - out, closer := displayWriter(output, flags) +func displayResticHelp(ctx commandContext, command string) { + out, closer := displayWriter(ctx.terminal) defer closer() resticBinary := "" - if configuration != nil { - if section, err := configuration.GetGlobalSection(); err == nil { + if ctx.config != nil { + if section, err := ctx.config.GetGlobalSection(); err == nil { resticBinary = section.ResticBinary } } @@ -187,11 +175,11 @@ func displayResticHelp(output io.Writer, configuration *config.Config, flags com out("restic binary not found: %s\n", err.Error()) } - if configuration != nil { - out("\nFlags applied by resticprofile (configuration \"%s\"):\n\n", ansiBold(configuration.GetConfigFile())) + if ctx.config != nil { + out("\nFlags applied by resticprofile (configuration \"%s\"):\n\n", ansi.Bold(ctx.config.GetConfigFile())) - if profileNames := configuration.GetProfileNames(); len(profileNames) > 0 { - profiles := configuration.GetProfiles() + if profileNames := ctx.config.GetProfileNames(); len(profileNames) > 0 { + profiles := ctx.config.GetProfiles() sort.Strings(profileNames) unescaper := strings.NewReplacer( `\\`, `^^`, @@ -202,14 +190,14 @@ func displayResticHelp(output io.Writer, configuration *config.Config, flags com ) for _, name := range profileNames { - out("\tprofile \"%s\":", ansiBold(name)) + out("\tprofile \"%s\":", ansi.Bold(name)) profile := profiles[name] cmdFlags := config.GetNonConfidentialArgs(profile, profile.GetCommandFlags(command)) for _, flag := range cmdFlags.GetAll() { if strings.HasPrefix(flag, "-") { out("\n\t\t") } - out("%s\t", ansiCyan(unescaper.Replace(flag))) + out("%s\t", ansi.Cyan(unescaper.Replace(flag))) } out("\n") } @@ -219,12 +207,9 @@ func displayResticHelp(output io.Writer, configuration *config.Config, flags com } } -func displayHelpCommand(output io.Writer, ctx commandContext) error { +func displayHelpCommand(ctx commandContext) error { flags := ctx.flags - out, closer := displayWriter(output, ctx.flags) - defer closer() - if flags.log == "" { clog.GetDefaultLogger().SetHandler(clog.NewDiscardHandler()) // disable log output } @@ -238,23 +223,23 @@ func displayHelpCommand(output io.Writer, ctx commandContext) error { } if helpForCommand == nil { - displayCommonUsageHelp(out("\n"), ctx) + displayCommonUsageHelp(ctx) } else if ctx.ownCommands.Exists(*helpForCommand, true) || ctx.ownCommands.Exists(*helpForCommand, false) { - displayOwnCommandHelp(out("\n"), *helpForCommand, ctx) + displayOwnCommandHelp(ctx, *helpForCommand) } else { - displayResticHelp(out(), ctx.config, flags, *helpForCommand) + displayResticHelp(ctx, *helpForCommand) } return nil } -func displayVersion(output io.Writer, ctx commandContext) error { - out, closer := displayWriter(output, ctx.flags) +func displayVersion(ctx commandContext) error { + out, closer := displayWriter(ctx.terminal) defer closer() - out("resticprofile version %s commit %s\n", ansiBold(version), ansiYellow(commit)) + out("resticprofile version %s commit %s\n", ansi.Bold(version), ansi.Yellow(commit)) // allow for the general verbose flag, or specified after the command arguments := ctx.request.arguments @@ -293,53 +278,53 @@ func displayVersion(output io.Writer, ctx commandContext) error { out("\n") out("\t%s:\n", "go modules") for _, dep := range bi.Deps { - out("\t\t%s\t%s\n", ansiCyan(dep.Path), dep.Version) + out("\t\t%s\t%s\n", ansi.Cyan(dep.Path), dep.Version) } out("\n") } return nil } -func displayProfilesCommand(output io.Writer, ctx commandContext) error { - displayProfiles(output, ctx.config, ctx.flags) - displayGroups(output, ctx.config, ctx.flags) +func displayProfilesCommand(ctx commandContext) error { + displayProfiles(ctx) + displayGroups(ctx) return nil } -func displayProfiles(output io.Writer, configuration *config.Config, flags commandLineFlags) { - out, closer := displayWriter(output, flags) +func displayProfiles(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() - profiles := configuration.GetProfiles() + profiles := ctx.config.GetProfiles() keys := sortedProfileKeys(profiles) if len(profiles) == 0 { out("\nThere's no available profile in the configuration\n") } else { - out("\n%s (name, sections, description):\n", ansiBold("Profiles available")) + out("\n%s (name, sections, description):\n", ansi.Bold("Profiles available")) for _, name := range keys { sections := profiles[name].DefinedCommands() sort.Strings(sections) if len(sections) == 0 { out("\t%s:\t(n/a)\t%s\n", name, profiles[name].Description) } else { - out("\t%s:\t(%s)\t%s\n", name, ansiCyan(strings.Join(sections, ", ")), profiles[name].Description) + out("\t%s:\t(%s)\t%s\n", name, ansi.Cyan(strings.Join(sections, ", ")), profiles[name].Description) } } } out("\n") } -func displayGroups(output io.Writer, configuration *config.Config, flags commandLineFlags) { - out, closer := displayWriter(output, flags) +func displayGroups(ctx commandContext) { + out, closer := displayWriter(ctx.terminal) defer closer() - groups := configuration.GetProfileGroups() + groups := ctx.config.GetProfileGroups() if len(groups) == 0 { return } - out("%s (name, profiles, description):\n", ansiBold("Groups available")) + out("%s (name, profiles, description):\n", ansi.Bold("Groups available")) for name, groupList := range groups { - out("\t%s:\t[%s]\t%s\n", name, ansiCyan(strings.Join(groupList.Profiles, ", ")), groupList.Description) + out("\t%s:\t[%s]\t%s\n", name, ansi.Cyan(strings.Join(groupList.Profiles, ", ")), groupList.Description) } out("\n") } @@ -352,96 +337,3 @@ func sortedProfileKeys(data map[string]*config.Profile) []string { sort.Strings(keys) return keys } - -// lineLengthWriter limits the max line length, adding line breaks ('\n') as needed. -// the writer detects the right most column (consecutive whitespace) and aligns content if possible. -type lineLengthWriter struct { - writer io.Writer - tokens []byte - maxLineLength, lastWhite, breakLength, lineLength int - ansiLength, lastWhiteAnsiLength int -} - -func newLineLengthWriter(writer io.Writer, maxLineLength int) *lineLengthWriter { - return &lineLengthWriter{writer: writer, maxLineLength: maxLineLength} -} - -func (l *lineLengthWriter) Write(p []byte) (n int, err error) { - var written int - inAnsi := false - offset := l.lineLength - lineLength := func() int { return l.lineLength - l.ansiLength } - - if len(l.tokens) == 0 { - l.tokens = []byte{' ', '\n'} - } - - for i := 0; i < len(p); i++ { - l.lineLength++ - ws := p[i] == l.tokens[0] // ' ' - br := p[i] == l.tokens[1] // '\n' - - // don't count ansi control sequences - if inAnsi = inAnsi || p[i] == 0x1b; inAnsi { - inAnsi = p[i] != 'm' - l.ansiLength++ - continue - } - - if !br && lineLength() > l.maxLineLength && l.lastWhite-offset > 0 { - lastWhiteIndex := l.lastWhite - offset - 1 - remainder := i - lastWhiteIndex - - if written, err = l.writer.Write(p[:lastWhiteIndex]); err == nil { - p = p[lastWhiteIndex+1:] - i = remainder - 1 - n += written + 1 - - _, _ = l.writer.Write(l.tokens[1:]) // write break (instead of WS at lastWhiteIndex) - for range l.breakLength { - _, _ = l.writer.Write(l.tokens[0:1]) // fill spaces for alignment - } - - l.lineLength = l.breakLength + remainder - l.lastWhite = l.breakLength - offset = l.breakLength - - l.ansiLength -= l.lastWhiteAnsiLength - l.lastWhiteAnsiLength = 0 - } else { - return - } - } - - if ws { - if l.lastWhite == l.lineLength-1 && lineLength() < l.maxLineLength*2/3 { - l.breakLength = lineLength() - } - l.lastWhite = l.lineLength - l.lastWhiteAnsiLength = l.ansiLength - - } else if br { - if written, err = l.writer.Write(p[:i+1]); err == nil { - p = p[i+1:] - i = -1 - n += written - - l.lineLength = 0 - l.lastWhite = 0 - l.breakLength = 0 - offset = 0 - - l.ansiLength = 0 - l.lastWhiteAnsiLength = 0 - } else { - return - } - } - } - - // write remainder - if written, err = l.writer.Write(p); err == nil { - n += written - } - return -} diff --git a/commands_display_test.go b/commands_display_test.go index ef27e1cf..1c481048 100644 --- a/commands_display_test.go +++ b/commands_display_test.go @@ -2,11 +2,11 @@ package main import ( "bytes" - "fmt" "runtime" "strings" "testing" + "github.com/creativeprojects/resticprofile/term" "github.com/fatih/color" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -21,290 +21,69 @@ var ansiColor = func() (c *color.Color) { var colored = ansiColor.SprintFunc() func TestDisplayWriter(t *testing.T) { - flags, noAnsiFlags := commandLineFlags{}, commandLineFlags{noAnsi: true} + write := func(v ...any) string { + buffer := new(bytes.Buffer) + recorder, err := term.NewRecorder(buffer) + require.NoError(t, err) - buffer := bytes.Buffer{} - write := func(clf commandLineFlags, v ...any) string { - buffer.Reset() - out, closer := displayWriter(&buffer, clf) + terminal := term.NewTerminal(term.WithStdoutRecorder(recorder), term.WithColors(true)) + out, closer := displayWriter(terminal) out(v...) closer() + recorder.Close() return buffer.String() } t.Run("write-plain", func(t *testing.T) { - actual := write(flags, "hello %s %d") + actual := write("hello %s %d") assert.Equal(t, "hello %s %d", actual) }) t.Run("write-with-format", func(t *testing.T) { - actual := write(flags, "hello %s %02d", "world", 5) + actual := write("hello %s %02d", "world", 5) assert.Equal(t, "hello world 05", actual) }) t.Run("write-tabs", func(t *testing.T) { - actual := write(flags, "col1\tcol2\tcol3\nvalue1\tvalue2\tvalue3") + actual := write("col1\tcol2\tcol3\nvalue1\tvalue2\tvalue3") assert.Equal(t, "col1 col2 col3\nvalue1 value2 value3", actual) }) - t.Run("no-ansi", func(t *testing.T) { - actual := write(flags, colored("test")) + t.Run("ansi", func(t *testing.T) { + actual := write(colored("test")) assert.Equal(t, colored("test"), actual) - - actual = write(noAnsiFlags, colored("test")) - assert.Equal(t, "test", actual) + assert.Contains(t, colored("test"), "\x1b[") }) } -func TestLineLengthWriter(t *testing.T) { - tests := []struct { - input, expected string - chunks, scale int - }{ - // test non-breakable - {input: strings.Repeat("-", 50), expected: strings.Repeat("-", 50), chunks: 15}, - - // test breakable without columns - { - input: strings.Repeat("word ", 20), - expected: "" + - strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + - strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + - strings.Repeat("word ", 4), - chunks: 5, scale: 6, - }, - - // test breakable with ANSI color - { - input: strings.Repeat(colored("word "), 20), - expected: "" + - strings.Repeat(colored("word "), 7) + colored("word\n") + - strings.Repeat(colored("word "), 7) + colored("word\n") + - strings.Repeat(colored("word "), 4), - }, - - // test breakable with 2 columns - { - input: "word word word word " + - strings.Repeat("word ", 20), - expected: "" + - "word word word word " + - "word word word\n" + - strings.Repeat(" word word word\n", 5) + - " word word ", - chunks: 3, scale: 15, - }, - - // test breakable with 2 columns and ANSI color - { - input: colored("word word word word ") + - strings.Repeat(colored("word "), 20), - expected: "" + - colored("word word word word ") + - colored("word ") + colored("word ") + colored("word\n ") + - strings.Repeat(colored("word ")+colored("word ")+colored("word\n "), 5) + - colored("word ") + colored("word "), - }, - - // test real-world content - { - input: ` -Usage of resticprofile: - resticprofile [resticprofile flags] [profile name.][restic command] [restic flags] - resticprofile [resticprofile flags] [profile name.][resticprofile command] [command specific flags] - -resticprofile flags: - -c, --config string configuration file (default "profiles") - --dry-run display the restic commands instead of running them - -f, --format string file format of the configuration (default is to use the file extension) - -h, --help display this help - --lock-wait duration wait up to duration to acquire a lock (syntax "1h5m30s") - -l, --log string logs to a target instead of the console - -n, --name string profile name (default "default") - --no-ansi disable ansi control characters (disable console colouring) - --no-lock skip profile lock file - --no-prio don't set any priority on load: used when started from a service that has already set the priority - -q, --quiet display only warnings and errors - --theme string console colouring theme (dark, light, none) (default "light") - --trace display even more debugging information - -v, --verbose display some debugging information - -w, --wait wait at the end until the user presses the enter key - - -resticprofile own commands: - help display help (run in verbose mode for detailed information) - version display version (run in verbose mode for detailed information) - self-update update to latest resticprofile (use -q/--quiet flag to update without confirmation) - profiles display profile names from the configuration file - show show all the details of the current profile - schedule schedule jobs from a profile (use --all flag to schedule all jobs of all profiles) - unschedule remove scheduled jobs of a profile (use --all flag to unschedule all profiles) - status display the status of scheduled jobs (use --all flag for all profiles) - generate generate resources (--random-key [size], --bash-completion & --zsh-completion) - -Documentation available at https://creativeprojects.github.io/resticprofile/ -`, - expected: ` -Usage of resticprofile: - resticprofile [resticprofile flags] - [profile name.][restic command] - [restic flags] - resticprofile [resticprofile flags] - [profile name.][resticprofile - command] [command specific flags] +func TestDisplayWriterNoColors(t *testing.T) { + write := func(v ...any) string { + buffer := new(bytes.Buffer) + recorder, err := term.NewRecorder(buffer) + require.NoError(t, err) -resticprofile flags: - -c, --config string - configuration - file (default - "profiles") - --dry-run display - the restic - commands - instead of - running them - -f, --format string file - format of the - configuration - (default is to - use the file - extension) - -h, --help display - this help - --lock-wait duration wait up to - duration to acquire a lock - (syntax "1h5m30s") - -l, --log string logs to a - target instead - of the console - -n, --name string profile - name (default - "default") - --no-ansi disable - ansi control - characters - (disable - console - colouring) - --no-lock skip - profile lock - file - --no-prio don't set - any priority - on load: used - when started - from a service - that has - already set - the priority - -q, --quiet display - only warnings - and errors - --theme string console - colouring - theme (dark, - light, none) - (default - "light") - --trace display - even more - debugging - information - -v, --verbose display - some debugging - information - -w, --wait wait at - the end until - the user - presses the - enter key - - -resticprofile own commands: - help display help (run in - verbose mode for - detailed information) - version display version (run - in verbose mode for - detailed information) - self-update update to latest - resticprofile (use - -q/--quiet flag to - update without - confirmation) - profiles display profile names - from the configuration - file - show show all the details - of the current profile - schedule schedule jobs from a - profile (use --all - flag to schedule all - jobs of all profiles) - unschedule remove scheduled jobs - of a profile (use - --all flag to - unschedule all - profiles) - status display the status of - scheduled jobs (use - --all flag for all - profiles) - generate generate resources - (--random-key [size], - --bash-completion & - --zsh-completion) - -Documentation available at -https://creativeprojects.github.io/resticprofile/ -`, - }, - } - - for i, test := range tests { - if test.scale == 0 { - test.scale = 1 - } - for chunkSize := 0; chunkSize <= test.chunks; chunkSize++ { - t.Run(fmt.Sprintf("%d-%d", i, chunkSize), func(t *testing.T) { - buffer := bytes.Buffer{} - writer := newLineLengthWriter(&buffer, 40) - input := []byte(test.input) - - var ( - n int - err error - ) - switch chunkSize { - case 0: - n, err = writer.Write(input) - assert.Equal(t, len(input), n) - default: - for len(input) > 0 && err == nil { - length := min(test.scale*chunkSize, len(input)) - n, err = writer.Write(input[:length]) - assert.Equal(t, length, n) - input = input[length:] - } - } - - assert.Nil(t, err) - assert.Equal(t, test.expected, buffer.String()) - }) - } + terminal := term.NewTerminal(term.WithStdoutRecorder(recorder), term.WithColors(false)) + out, closer := displayWriter(terminal) + out(v...) + closer() + recorder.Close() + return buffer.String() } + actual := write(colored("test")) + assert.Equal(t, "test", actual) + assert.Contains(t, colored("test"), "\x1b[") } func TestDisplayVersionVerbose1(t *testing.T) { buffer := &bytes.Buffer{} - err := displayVersion(buffer, commandContext{Context: Context{flags: commandLineFlags{verbose: true}}}) + err := displayVersion(commandContext{Context: Context{terminal: term.NewTerminal(term.WithStdout(buffer)), flags: commandLineFlags{verbose: true}}}) require.NoError(t, err) assert.True(t, strings.Contains(buffer.String(), runtime.GOOS)) } func TestDisplayVersionVerbose2(t *testing.T) { buffer := &bytes.Buffer{} - err := displayVersion(buffer, commandContext{Context: Context{request: Request{arguments: []string{"-v"}}}}) + err := displayVersion(commandContext{Context: Context{terminal: term.NewTerminal(term.WithStdout(buffer)), request: Request{arguments: []string{"-v"}}}}) require.NoError(t, err) assert.True(t, strings.Contains(buffer.String(), runtime.GOOS)) } diff --git a/commands_generate.go b/commands_generate.go index 00b47907..9b8ed085 100644 --- a/commands_generate.go +++ b/commands_generate.go @@ -29,7 +29,7 @@ var zshCompletionScript string //go:embed contrib/completion/fish-completion.fish var fishCompletionScript string -func generateCommand(output io.Writer, ctx commandContext) (err error) { +func generateCommand(ctx commandContext) (err error) { args := ctx.request.arguments // enforce no-log logger := clog.GetDefaultLogger() @@ -37,18 +37,18 @@ func generateCommand(output io.Writer, ctx commandContext) (err error) { logger.SetHandler(clog.NewDiscardHandler()) if slices.Contains(args, "--bash-completion") { - _, err = fmt.Fprintln(output, bashCompletionScript) + _, err = ctx.terminal.Println(bashCompletionScript) } else if slices.Contains(args, "--config-reference") { - err = generateConfigReference(output, args[slices.Index(args, "--config-reference")+1:]) + err = generateConfigReference(ctx.terminal, args[slices.Index(args, "--config-reference")+1:]) } else if slices.Contains(args, "--json-schema") { - err = generateJsonSchema(output, args[slices.Index(args, "--json-schema")+1:]) + err = generateJsonSchema(ctx.terminal, args[slices.Index(args, "--json-schema")+1:]) } else if slices.Contains(args, "--random-key") { ctx.flags.resticArgs = args[slices.Index(args, "--random-key"):] - err = randomKey(output, ctx) + err = randomKey(ctx) } else if slices.Contains(args, "--zsh-completion") { - _, err = fmt.Fprintln(output, zshCompletionScript) + _, err = ctx.terminal.Println(zshCompletionScript) } else if slices.Contains(args, "--fish-completion") { - _, err = fmt.Fprintln(output, fishCompletionScript) + _, err = ctx.terminal.Println(fishCompletionScript) } else { err = fmt.Errorf("nothing to generate for: %s", strings.Join(args, ", ")) } diff --git a/commands_schedule.go b/commands_schedule.go index 552ea3da..fbe2d0ec 100644 --- a/commands_schedule.go +++ b/commands_schedule.go @@ -3,7 +3,6 @@ package main import ( "errors" "fmt" - "io" "maps" "slices" "strings" @@ -20,7 +19,7 @@ const ( ) // createSchedule command -func createSchedule(_ io.Writer, ctx commandContext) error { +func createSchedule(ctx commandContext) error { c := ctx.config request := ctx.request args := ctx.request.arguments @@ -76,7 +75,7 @@ func createSchedule(_ io.Writer, ctx commandContext) error { return nil } -func removeSchedule(_ io.Writer, ctx commandContext) error { +func removeSchedule(ctx commandContext) error { var err error c := ctx.config request := ctx.request @@ -116,7 +115,7 @@ func removeSchedule(_ io.Writer, ctx commandContext) error { return nil } -func statusSchedule(w io.Writer, ctx commandContext) error { +func statusSchedule(ctx commandContext) error { c := ctx.config request := ctx.request args := ctx.request.arguments @@ -345,7 +344,7 @@ func prepareScheduledProfile(ctx *Context) { } } -func runSchedule(_ io.Writer, cmdCtx commandContext) error { +func runSchedule(cmdCtx commandContext) error { err := startProfileOrGroup(&cmdCtx.Context, runProfile) if err != nil { return err diff --git a/commands_schedule_test.go b/commands_schedule_test.go index 0550a798..4d591407 100644 --- a/commands_schedule_test.go +++ b/commands_schedule_test.go @@ -108,6 +108,10 @@ func TestStatusScheduleIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -115,13 +119,11 @@ func TestStatusScheduleIntegrationUsingCrontab(t *testing.T) { request: Request{ profile: tc.profileName, }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) if tc.err != nil { assert.ErrorIs(t, err, tc.err) return @@ -157,6 +159,10 @@ func TestRemoveScheduleIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -164,14 +170,12 @@ func TestRemoveScheduleIntegrationUsingCrontab(t *testing.T) { request: Request{ profile: "profile", }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) // this should remove the 2 lines that resemble a resticprofile schedule - err = removeSchedule(output, ctx) + err = removeSchedule(ctx) require.NoError(t, err) result, err := os.ReadFile(crontab) @@ -179,7 +183,7 @@ func TestRemoveScheduleIntegrationUsingCrontab(t *testing.T) { fmt.Println(string(result)) // but one line in error should be left in the crontab - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) require.ErrorIs(t, err, crond.ErrEntryNoMatch) } @@ -202,6 +206,10 @@ func TestCreateScheduleIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -209,14 +217,12 @@ func TestCreateScheduleIntegrationUsingCrontab(t *testing.T) { request: Request{ profile: "profile-schedule-struct", }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) // create one schedule - err = createSchedule(output, ctx) + err = createSchedule(ctx) require.NoError(t, err) result, err := os.ReadFile(crontab) @@ -225,7 +231,7 @@ func TestCreateScheduleIntegrationUsingCrontab(t *testing.T) { fmt.Println(string(result)) output.Reset() - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) require.NoError(t, err) // verify the new schedule was added @@ -258,6 +264,10 @@ func TestCreateScheduleOverwriteExistingIntegrationUsingCrontab(t *testing.T) { require.NoError(t, err) require.NotNil(t, global) + output := &bytes.Buffer{} + terminal := term.Set(term.NewTerminal(term.WithStdout(output))) + defer term.Set(nil) + ctx := commandContext{ Context: Context{ config: cfg, @@ -265,14 +275,12 @@ func TestCreateScheduleOverwriteExistingIntegrationUsingCrontab(t *testing.T) { request: Request{ arguments: []string{"--all", "--reload", "--no-start"}, }, + terminal: terminal, }, } - output := &bytes.Buffer{} - term.SetOutput(output) - defer term.SetOutput(os.Stdout) // create (or update) two schedules - err = createSchedule(output, ctx) + err = createSchedule(ctx) require.NoError(t, err) result, err := os.ReadFile(crontab) @@ -282,7 +290,7 @@ func TestCreateScheduleOverwriteExistingIntegrationUsingCrontab(t *testing.T) { fmt.Println(string(result)) output.Reset() - err = statusSchedule(output, ctx) + err = statusSchedule(ctx) require.NoError(t, err) // verify the schedule was replaced diff --git a/commands_test.go b/commands_test.go index 38a6ba57..f1c8d2a7 100644 --- a/commands_test.go +++ b/commands_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "os" "sort" "strings" "testing" @@ -12,6 +11,7 @@ import ( "github.com/creativeprojects/resticprofile/config" "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/schedule" + "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util/collect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,29 +19,31 @@ import ( func TestPanicCommand(t *testing.T) { assert.Panics(t, func() { - _ = panicCommand(nil, commandContext{}) + _ = panicCommand(commandContext{}) }) } func TestRandomKeyOfInvalidSize(t *testing.T) { - assert.Error(t, randomKey(os.Stdout, commandContext{ + assert.Error(t, randomKey(commandContext{ Context: Context{ - flags: commandLineFlags{resticArgs: []string{"restic", "size"}}, + flags: commandLineFlags{resticArgs: []string{"restic", "size"}}, + terminal: term.NewTerminal(), }, })) } func TestRandomKeyOfZeroSize(t *testing.T) { - assert.Error(t, randomKey(os.Stdout, commandContext{ + assert.Error(t, randomKey(commandContext{ Context: Context{ - flags: commandLineFlags{resticArgs: []string{"restic", "0"}}, + flags: commandLineFlags{resticArgs: []string{"restic", "0"}}, + terminal: term.NewTerminal(), }, })) } func TestRandomKey(t *testing.T) { // doesn't look like much, but it's testing the random generator is not throwing an error - assert.NoError(t, randomKey(os.Stdout, commandContext{})) + assert.NoError(t, randomKey(commandContext{Context: Context{terminal: term.NewTerminal()}})) } func TestRemovableSchedules(t *testing.T) { @@ -199,9 +201,12 @@ func TestCompleteCall(t *testing.T) { for _, test := range testTable { t.Run(strings.Join(test.args, " "), func(t *testing.T) { buffer := &strings.Builder{} - assert.Nil(t, completeCommand(buffer, commandContext{ + assert.Nil(t, completeCommand(commandContext{ ownCommands: ownCommands, - Context: Context{request: Request{arguments: test.args}}, + Context: Context{ + request: Request{arguments: test.args}, + terminal: term.NewTerminal(term.WithStdout(buffer)), + }, })) assert.Equal(t, test.expected, buffer.String()) }) @@ -213,33 +218,38 @@ func TestGenerateCommand(t *testing.T) { contextWithArguments := func(args []string) commandContext { t.Helper() - return commandContext{Context: Context{request: Request{arguments: args}}} + return commandContext{ + Context: Context{ + request: Request{arguments: args}, + terminal: term.NewTerminal(term.WithStdout(buffer)), + }, + } } t.Run("--bash-completion", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--bash-completion"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--bash-completion"}))) assert.Equal(t, strings.TrimSpace(bashCompletionScript), strings.TrimSpace(buffer.String())) assert.Contains(t, bashCompletionScript, "#!/usr/bin/env bash") }) t.Run("--zsh-completion", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--zsh-completion"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--zsh-completion"}))) assert.Equal(t, strings.TrimSpace(zshCompletionScript), strings.TrimSpace(buffer.String())) assert.Contains(t, zshCompletionScript, "#!/usr/bin/env zsh") }) t.Run("--fish-completion", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--fish-completion"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--fish-completion"}))) assert.Equal(t, strings.TrimSpace(fishCompletionScript), strings.TrimSpace(buffer.String())) assert.Contains(t, fishCompletionScript, "#!/usr/bin/env fish") }) t.Run("--config-reference", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--config-reference", "--to", t.TempDir()}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--config-reference", "--to", t.TempDir()}))) ref := buffer.String() assert.Contains(t, ref, "generating reference.gomd") assert.Contains(t, ref, "generating profile section") @@ -248,7 +258,7 @@ func TestGenerateCommand(t *testing.T) { t.Run("--json-schema global", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "global"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "global"}))) ref := buffer.String() assert.Contains(t, ref, `"$schema"`) assert.Contains(t, ref, "/jsonschema/config-1.json") @@ -262,24 +272,24 @@ func TestGenerateCommand(t *testing.T) { t.Run("--json-schema no-option", func(t *testing.T) { buffer.Reset() - assert.Error(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema"}))) + assert.Error(t, generateCommand(contextWithArguments([]string{"--json-schema"}))) }) t.Run("--json-schema invalid-option", func(t *testing.T) { buffer.Reset() - assert.Error(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "_invalid_"}))) + assert.Error(t, generateCommand(contextWithArguments([]string{"--json-schema", "_invalid_"}))) }) t.Run("--json-schema v1", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "v1"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "v1"}))) ref := buffer.String() assert.Contains(t, ref, "/jsonschema/config-1.json") }) t.Run("--json-schema v2", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "v2"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "v2"}))) ref := buffer.String() assert.Contains(t, ref, "\"profiles\":") assert.Contains(t, ref, "/jsonschema/config-2.json") @@ -287,14 +297,14 @@ func TestGenerateCommand(t *testing.T) { t.Run("--json-schema --version 0.13 v1", func(t *testing.T) { buffer.Reset() - assert.NoError(t, generateCommand(buffer, contextWithArguments([]string{"--json-schema", "--version", "0.13", "v1"}))) + assert.NoError(t, generateCommand(contextWithArguments([]string{"--json-schema", "--version", "0.13", "v1"}))) ref := buffer.String() assert.Contains(t, ref, "/jsonschema/config-1-restic-0-13.json") }) t.Run("--random-key", func(t *testing.T) { buffer.Reset() - assert.Nil(t, generateCommand(buffer, contextWithArguments([]string{"--random-key", "512"}))) + assert.Nil(t, generateCommand(contextWithArguments([]string{"--random-key", "512"}))) assert.Equal(t, 684, len(strings.TrimSpace(buffer.String()))) }) @@ -303,7 +313,7 @@ func TestGenerateCommand(t *testing.T) { opts := []string{"", "invalid", "--unknown"} for _, option := range opts { buffer.Reset() - err := generateCommand(buffer, contextWithArguments([]string{option})) + err := generateCommand(contextWithArguments([]string{option})) assert.EqualError(t, err, fmt.Sprintf("nothing to generate for: %s", option)) assert.Equal(t, 0, buffer.Len()) } @@ -348,7 +358,7 @@ func TestCreateScheduleWhenNoneAvailable(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = createSchedule(nil, commandContext{ + err = createSchedule(commandContext{ Context: Context{ config: cfg, flags: commandLineFlags{ @@ -366,7 +376,7 @@ func TestCreateScheduleAll(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = createSchedule(nil, commandContext{ + err = createSchedule(commandContext{ Context: Context{ config: cfg, request: Request{ @@ -426,7 +436,7 @@ func TestRunScheduleNoScheduleName(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = runSchedule(nil, commandContext{ + err = runSchedule(commandContext{ Context: Context{ config: cfg, flags: commandLineFlags{ @@ -443,7 +453,7 @@ func TestRunScheduleWrongScheduleName(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = runSchedule(nil, commandContext{ + err = runSchedule(commandContext{ Context: Context{ request: Request{arguments: []string{"wrong"}}, config: cfg, @@ -461,7 +471,7 @@ func TestRunScheduleProfileUnknown(t *testing.T) { cfg, err := config.Load(bytes.NewBufferString("[default]"), "toml") assert.NoError(t, err) - err = runSchedule(nil, commandContext{ + err = runSchedule(commandContext{ Context: Context{ request: Request{arguments: []string{"backup@profile"}}, config: cfg, @@ -472,6 +482,6 @@ func TestRunScheduleProfileUnknown(t *testing.T) { func TestBatteryCommand(t *testing.T) { buffer := &bytes.Buffer{} - err := batteryCommand(buffer, commandContext{}) + err := batteryCommand(commandContext{Context: Context{terminal: term.NewTerminal(term.WithStdout(buffer))}}) require.NoError(t, err) } diff --git a/config/flag.go b/config/flag.go index 35ab236c..b4c7bcd6 100644 --- a/config/flag.go +++ b/config/flag.go @@ -232,7 +232,10 @@ func stringify(value reflect.Value, onlySimplyValues bool) ([]string, bool) { sort.Strings(flatMap) return flatMap, len(flatMap) > 0 - case reflect.Interface: + case reflect.Interface, reflect.Pointer: + if value.IsNil() { + return emptyStringArray, false + } return stringify(value.Elem(), onlySimplyValues) default: diff --git a/config/flag_test.go b/config/flag_test.go index ac126f47..8b60c885 100644 --- a/config/flag_test.go +++ b/config/flag_test.go @@ -10,11 +10,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestPointerValueShouldReturnErrorMessage(t *testing.T) { +func TestPointerValueShouldReturnValue(t *testing.T) { concrete := "test" value := &concrete argValue, _ := stringifyValueOf(value) - assert.Equal(t, []string{"ERROR: unexpected type ptr"}, argValue) + assert.Equal(t, []string{"test"}, argValue) } func TestNilValueFlag(t *testing.T) { diff --git a/config/info.go b/config/info.go index 7c85d44c..880a1a02 100644 --- a/config/info.go +++ b/config/info.go @@ -672,11 +672,10 @@ func NewProfileInfoForRestic(resticVersion string, withDefaultOptions bool) Prof // Building initial set including generic sections (from data model) profileSet := propertySetFromType(infoTypes.profile) { - genericSection := propertySetFromType(infoTypes.genericSection) for _, name := range infoTypes.genericSectionNames { pi := new(propertyInfo) pi.nested = &namedPropertySet{ - propertySet: genericSection, + propertySet: propertySetFromType(infoTypes.genericSection), name: name, } profileSet.properties[name] = pi diff --git a/context.go b/context.go index 7a968951..01f1efd2 100644 --- a/context.go +++ b/context.go @@ -6,6 +6,7 @@ import ( "github.com/creativeprojects/resticprofile/config" "github.com/creativeprojects/resticprofile/constants" + "github.com/creativeprojects/resticprofile/term" ) type Request struct { @@ -17,8 +18,8 @@ type Request struct { } // Context for running a profile command. -// Not everything is always available, -// but any information should be added to the context as soon as known. +// All fields are not always populated at the same time, +// as the context is built up step by step when running a profile command. type Context struct { request Request flags commandLineFlags @@ -35,6 +36,7 @@ type Context struct { noLock bool // skip profile lock file lockWait time.Duration // wait up to duration to acquire a lock legacyArgs bool // I'm not even sure it's been used by anyone? + terminal *term.Terminal } func CreateContext(flags commandLineFlags, global *config.Global, cfg *config.Config, ownCommands *OwnCommands) (*Context, error) { @@ -135,6 +137,12 @@ func (c *Context) WithProfile(profileName string) *Context { return newContext } +func (c *Context) WithTerminal(terminal *term.Terminal) *Context { + newContext := c.clone() + newContext.terminal = terminal + return newContext +} + func (c *Context) clone() *Context { clone := *c return &clone diff --git a/integration_test.go b/integration_test.go index 5e8aea48..ad567d4d 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1,8 +1,6 @@ package main import ( - "bytes" - "os" "path/filepath" "strings" "testing" @@ -131,12 +129,12 @@ func TestFromConfigFileToCommandLine(t *testing.T) { command: fixture.commandName, } wrapper := newResticWrapper(ctx) - buffer := &bytes.Buffer{} + // setting the output via the package global setter could lead to some issues // when some tests are running in parallel. I should fix that at some point :-/ - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) err = wrapper.runCommand(fixture.commandName) - term.SetOutput(os.Stdout) + stdout := term.StopRecording() require.NoError(t, err) @@ -148,7 +146,7 @@ func TestFromConfigFileToCommandLine(t *testing.T) { t.SkipNow() } } - assert.Equal(t, expected, strings.TrimSpace(buffer.String())) + assert.Equal(t, expected, strings.TrimSpace(stdout)) }) if platform.IsWindows() { @@ -172,16 +170,15 @@ func TestFromConfigFileToCommandLine(t *testing.T) { legacyArgs: true, } wrapper := newResticWrapper(ctx) - buffer := &bytes.Buffer{} // setting the output via the package global setter could lead to some issues // when some tests are running in parallel. I should fix that at some point :-/ - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) err = wrapper.runCommand(fixture.commandName) - term.SetOutput(os.Stdout) + content := term.StopRecording() require.NoError(t, err) - assert.Equal(t, fixture.legacy, strings.TrimSpace(buffer.String())) + assert.Equal(t, fixture.legacy, strings.TrimSpace(content)) }) } }) diff --git a/logger.go b/logger.go index bcb0619c..40c5d25b 100644 --- a/logger.go +++ b/logger.go @@ -8,7 +8,6 @@ import ( "path/filepath" "slices" "strings" - "time" "github.com/creativeprojects/clog" "github.com/creativeprojects/resticprofile/constants" @@ -114,9 +113,8 @@ func getFileHandler(logfile string) (*clog.StandardLogHandler, io.Writer, error) } // create a platform aware log file appender - keepOpen, appender := true, appendFunc(nil) + var appender util.AsyncFileWriterAppendFunc if platform.IsWindows() { - keepOpen = false appender = func(dst []byte, c byte) []byte { switch c { case '\n': @@ -128,7 +126,11 @@ func getFileHandler(logfile string) (*clog.StandardLogHandler, io.Writer, error) } } - writer, err := newDeferredFileWriter(logfile, keepOpen, appender) + writer, err := util.NewAsyncFileWriter( + logfile, + util.WithAsyncFileAppendFunc(appender), + util.WithAsyncFilePerm(0644), + ) if err != nil { return nil, nil, err } @@ -173,140 +175,3 @@ func changeLevelFilter(level clog.LogLevel) { filter.SetLevel(level) } } - -// deferredFileWriter accumulates Write requests and writes them at a fixed rate (every 250 ms) -type deferredFileWriter struct { - done, flush chan chan error - data chan []byte -} - -func (d *deferredFileWriter) Close() error { - req := make(chan error) - d.done <- req - return <-req -} - -func (d *deferredFileWriter) Flush() error { - req := make(chan error) - d.flush <- req - return <-req -} - -func (d *deferredFileWriter) Write(data []byte) (n int, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic: %v", r) - } - }() - c := make([]byte, len(data)) - n = copy(c, data) - d.data <- c - return -} - -type appendFunc func(dst []byte, c byte) []byte - -func newDeferredFileWriter(filename string, keepOpen bool, appender appendFunc) (io.WriteCloser, error) { - d := &deferredFileWriter{ - flush: make(chan chan error), - done: make(chan chan error), - data: make(chan []byte, 64), - } - - var ( - buffer []byte - lastError error - file *os.File - ) - - closeFile := func() { - if file != nil { - lastError = file.Close() - file = nil - } - } - - flush := func(alsoEmpty bool) { - if len(buffer) == 0 && !alsoEmpty { - return - } - if file == nil { - file, lastError = os.OpenFile(filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) //nolint:gosec - } - if file != nil { - var written int - written, lastError = file.Write(buffer) - if written == len(buffer) { - buffer = buffer[:0] - } else { - buffer = buffer[written:] - } - } - if keepOpen { - _ = file.Sync() - } else { - closeFile() - } - } - - // test if we can create the file - buffer = make([]byte, 0, 4096) - flush(true) - - // data appending - addToBuffer := func(data []byte) { - buffer = append(buffer, data...) // fast path - } - if appender != nil { - addToBuffer = func(data []byte) { - for _, c := range data { - buffer = appender(buffer, c) - } - } - } - - addPendingData := func(size int) { - for ; size > 0; size-- { - select { - case data, ok := <-d.data: - if ok { - addToBuffer(data) - } else { - return // closed - } - default: - return // no-more-data - } - } - } - - // data transport - go func() { - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case data := <-d.data: - addToBuffer(data) - case <-ticker.C: - flush(false) - case req := <-d.flush: - addPendingData(1024) - flush(false) - req <- lastError - case req := <-d.done: - close(d.done) - close(d.flush) - close(d.data) - addPendingData(1024) - flush(false) - closeFile() - req <- lastError - return - } - } - }() - - return d, lastError -} diff --git a/logger_test.go b/logger_test.go index 43cb0afe..3ed5421c 100644 --- a/logger_test.go +++ b/logger_test.go @@ -53,8 +53,8 @@ func TestFileHandler(t *testing.T) { require.NoError(t, err) defer handler.Close() - require.Implements(t, (*term.Flusher)(nil), writer) - flusher := writer.(term.Flusher) + require.Implements(t, (*util.Flusher)(nil), writer) + flusher := writer.(util.Flusher) log := func(line string) { assert.NoError(t, handler.LogEntry(clog.LogEntry{Level: clog.LevelInfo, Format: line})) diff --git a/main.go b/main.go index c06edfc1..c621bb14 100644 --- a/main.go +++ b/main.go @@ -48,29 +48,37 @@ func main() { // run shutdown hooks just before returning an exit code defer shutdown.RunHooks() + terminal := term.Set(term.NewTerminal()) + args := os.Args[1:] _, flags, flagErr := loadFlags(args) if flagErr != nil && !errors.Is(flagErr, pflag.ErrHelp) { - term.Println(flagErr) - _ = displayHelpCommand(os.Stdout, commandContext{ + terminal.Println(flagErr) + _ = displayHelpCommand(commandContext{ ownCommands: ownCommands, Context: Context{ flags: flags, request: Request{ arguments: args, }, + terminal: terminal, }, }) exitCode = constants.ExitErrorInvalidFlags return } + // Configure terminal color + if flags.noAnsi || flags.theme == "none" { + terminal = term.Set(term.NewTerminal(term.WithColors(false))) // disable colors + } + if flags.wait { // keep the console running at the end of the program // so we can see what's going on defer func() { - term.Println("\n\nPress the Enter Key to continue...") - _, _ = fmt.Scanln() + terminal.Println("\n\nPress Enter to continue...") + _, _ = terminal.Scanln() }() } @@ -83,13 +91,14 @@ func main() { // help if flags.help || errors.Is(flagErr, pflag.ErrHelp) { - _ = displayHelpCommand(os.Stdout, commandContext{ + _ = displayHelpCommand(commandContext{ ownCommands: ownCommands, Context: Context{ flags: flags, request: Request{ arguments: args, }, + terminal: terminal, }, }) return @@ -106,7 +115,8 @@ func main() { setupRemoteLogger(flags, client) // also redirect the terminal through the client - term.SetAllOutput(term.NewRemoteTerm(client)) + remoteTerm := term.NewRemoteTerm(client) + terminal = term.Set(term.NewTerminal(term.WithStdout(remoteTerm), term.WithStderr(remoteTerm), term.WithColors(!flags.noAnsi))) } else { logTarget, commandOutput := "", "" if ctx != nil { @@ -118,6 +128,7 @@ func main() { flags.stderr = true } term.PrintToError = flags.stderr + terminal = term.Set(term.NewTerminal(term.WithStdout(os.Stderr), term.WithColors(!flags.noAnsi))) } if logTarget != "" && logTarget != "-" { if closer, err := setupTargetLogger(flags, logTarget, commandOutput); err == nil { @@ -171,6 +182,7 @@ func main() { command: flags.resticArgs[0], arguments: flags.resticArgs[1:], }, + terminal: terminal, } // try to load the config and setup logging for own command cfg, global, err := loadConfig(flags, true) @@ -203,6 +215,8 @@ func main() { return } + ctx = ctx.WithTerminal(terminal) + // check if we're running on battery if shouldStopOnBattery(ctx.stopOnBattery) { exitCode = constants.ExitRunningOnBattery @@ -285,8 +299,9 @@ func main() { if err != nil { clog.Error(err) if errors.Is(err, ErrProfileNotFound) { - displayProfiles(os.Stdout, ctx.config, flags) - displayGroups(os.Stdout, ctx.config, flags) + commandContext := commandContext{Context: *ctx} + displayProfiles(commandContext) + displayGroups(commandContext) } exitCode = constants.ExitGeneralError return diff --git a/own_commands.go b/own_commands.go index 6db21d04..a68d5dfc 100644 --- a/own_commands.go +++ b/own_commands.go @@ -2,8 +2,6 @@ package main import ( "fmt" - "io" - "os" "strings" "github.com/creativeprojects/clog" @@ -20,14 +18,14 @@ type ownCommand struct { name string description string longDescription string - pre func(*Context) error // pre-command action (for checking the context) - action func(io.Writer, commandContext) error // run command action - needConfiguration bool // true if the action needs a configuration file loaded - hide bool // don't display the command in help and completion - hideInCompletion bool // don't display the command in completion - noProfile bool // true if the command doesn't need a profile name - experimental bool // display a warning when using this command - flags map[string]string // own command flags should be simple enough to be handled manually for now + pre func(*Context) error // pre-command action (for checking the context) + action func(commandContext) error // run command action + needConfiguration bool // true if the action needs a configuration file loaded + hide bool // don't display the command in help and completion + hideInCompletion bool // don't display the command in completion + noProfile bool // true if the command doesn't need a profile name + experimental bool // display a warning when using this command + flags map[string]string // own command flags should be simple enough to be handled manually for now } // OwnCommands is a list of resticprofile commands @@ -68,7 +66,7 @@ func (o *OwnCommands) Run(ctx *Context) error { if command.experimental { clog.Warningf("%s: this command is experimental and its behaviour may change in the future", ctx.request.command) } - return command.action(os.Stdout, commandContext{ + return command.action(commandContext{ ownCommands: o, Context: *ctx, }) diff --git a/own_commands_test.go b/own_commands_test.go index 09334623..386c8e9d 100644 --- a/own_commands_test.go +++ b/own_commands_test.go @@ -2,10 +2,10 @@ package main import ( "errors" - "io" "strings" "testing" + "github.com/creativeprojects/resticprofile/term" "github.com/stretchr/testify/assert" ) @@ -40,15 +40,15 @@ func fakeCommands() *OwnCommands { return ownCommands } -func firstCommand(_ io.Writer, _ commandContext) error { +func firstCommand(_ commandContext) error { return errors.New("first") } -func secondCommand(_ io.Writer, _ commandContext) error { +func secondCommand(_ commandContext) error { return errors.New("second") } -func thirdCommand(_ io.Writer, _ commandContext) error { +func thirdCommand(_ commandContext) error { return errors.New("third") } @@ -58,13 +58,15 @@ func pre(_ *Context) error { func TestDisplayOwnCommands(t *testing.T) { buffer := &strings.Builder{} - displayOwnCommands(buffer, commandContext{ownCommands: fakeCommands()}) + terminal := term.NewTerminal(term.WithStdout(buffer)) + displayOwnCommands(commandContext{ownCommands: fakeCommands(), Context: Context{terminal: terminal}}) assert.Equal(t, " first first first\n second second second\n", buffer.String()) } func TestDisplayOwnCommand(t *testing.T) { buffer := &strings.Builder{} - displayOwnCommandHelp(buffer, "second", commandContext{ownCommands: fakeCommands()}) + terminal := term.NewTerminal(term.WithStdout(buffer)) + displayOwnCommandHelp(commandContext{ownCommands: fakeCommands(), Context: Context{terminal: terminal}}, "second") assert.Equal(t, `Purpose: second second Usage: diff --git a/schedule/handler_crond.go b/schedule/handler_crond.go index 4e819e8b..48b9e31f 100644 --- a/schedule/handler_crond.go +++ b/schedule/handler_crond.go @@ -12,6 +12,7 @@ import ( "github.com/creativeprojects/resticprofile/crond" "github.com/creativeprojects/resticprofile/platform" "github.com/creativeprojects/resticprofile/shell" + "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/user" "github.com/spf13/afero" ) @@ -63,7 +64,7 @@ func (h *HandlerCrond) DisplaySchedules(profile, command string, schedules []str if err != nil { return err } - displayParsedSchedules(profile, command, events) + displayParsedSchedules(term.Get(), profile, command, events) return nil } diff --git a/schedule/handler_darwin.go b/schedule/handler_darwin.go index 5912d0a2..db6bfeb6 100644 --- a/schedule/handler_darwin.go +++ b/schedule/handler_darwin.go @@ -69,7 +69,7 @@ func (h *HandlerLaunchd) DisplaySchedules(profile, command string, schedules []s if err != nil { return err } - displayParsedSchedules(profile, command, events) + displayParsedSchedules(term.Get(), profile, command, events) return nil } diff --git a/schedule/handler_systemd.go b/schedule/handler_systemd.go index b5bccc83..47b66b2a 100644 --- a/schedule/handler_systemd.go +++ b/schedule/handler_systemd.go @@ -91,7 +91,7 @@ func (h *HandlerSystemd) ParseSchedules(schedules []string) ([]*calendar.Event, // DisplaySchedules displays the schedules through the systemd-analyze command func (h *HandlerSystemd) DisplaySchedules(profile, command string, schedules []string) error { - return displaySystemdSchedules(profile, command, schedules) + return displaySystemdSchedules(term.Get(), profile, command, schedules) } // DisplayStatus displays the status of all the timers installed on that profile. Example: @@ -597,7 +597,7 @@ func systemctlCommand(args ...string) (*exec.Cmd, error) { return exec.CommandContext(context.TODO(), binary, args...), nil } -func displaySystemdSchedules(profile, command string, schedules []string) error { +func displaySystemdSchedules(terminal *term.Terminal, profile, command string, schedules []string) error { binary, err := exec.LookPath(analyzeBinary) if err != nil { return fmt.Errorf("cannot find %q: %w", analyzeBinary, err) @@ -607,17 +607,17 @@ func displaySystemdSchedules(profile, command string, schedules []string) error if schedule == "" { return errors.New("empty schedule") } - displayHeader(profile, command, index+1, len(schedules)) + displayHeader(terminal, profile, command, index+1, len(schedules)) cmd := exec.CommandContext(context.TODO(), binary, "calendar", schedule) - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() err = cmd.Run() if err != nil { return err } } - term.Print(platform.LineSeparator) + terminal.Print(platform.LineSeparator) return nil } diff --git a/schedule/handler_systemd_test.go b/schedule/handler_systemd_test.go index 615ff652..defa20df 100644 --- a/schedule/handler_systemd_test.go +++ b/schedule/handler_systemd_test.go @@ -262,16 +262,15 @@ func TestPermissionToSystemd(t *testing.T) { } func TestDisplaySystemdSchedulesWithEmpty(t *testing.T) { - err := displaySystemdSchedules("profile", "command", []string{""}) + err := displaySystemdSchedules(term.NewTerminal(), "profile", "command", []string{""}) require.Error(t, err) } func TestDisplaySystemdSchedules(t *testing.T) { buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + terminal := term.NewTerminal(term.WithStdout(buffer)) - err := displaySystemdSchedules("profile", "command", []string{"daily"}) + err := displaySystemdSchedules(terminal, "profile", "command", []string{"daily"}) require.NoError(t, err) output := buffer.String() diff --git a/schedule/handler_windows.go b/schedule/handler_windows.go index 988b94ee..54111f70 100644 --- a/schedule/handler_windows.go +++ b/schedule/handler_windows.go @@ -10,6 +10,7 @@ import ( "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/schtasks" "github.com/creativeprojects/resticprofile/shell" + "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/user" ) @@ -38,7 +39,7 @@ func (h *HandlerWindows) DisplaySchedules(profile, command string, schedules []s if err != nil { return err } - displayParsedSchedules(profile, command, events) + displayParsedSchedules(term.Get(), profile, command, events) return nil } diff --git a/schedule/schedules.go b/schedule/schedules.go index 7daf73cb..3f261882 100644 --- a/schedule/schedules.go +++ b/schedule/schedules.go @@ -11,16 +11,16 @@ import ( "github.com/creativeprojects/resticprofile/term" ) -func displayHeader(profile, command string, index, total int) { - term.Print(platform.LineSeparator) +func displayHeader(terminal *term.Terminal, profile, command string, index, total int) { + terminal.Print(platform.LineSeparator) header := fmt.Sprintf("Profile (or Group) %s: %s schedule", profile, command) if total > 1 { header += fmt.Sprintf(" %d/%d", index, total) } - term.Print(header) - term.Print(platform.LineSeparator) - term.Print(strings.Repeat("=", len(header))) - term.Print(platform.LineSeparator) + terminal.Print(header) + terminal.Print(platform.LineSeparator) + terminal.Print(strings.Repeat("=", len(header))) + terminal.Print(platform.LineSeparator) } // parseSchedules creates a *calendar.Event from a string @@ -40,20 +40,20 @@ func parseSchedules(schedules []string) ([]*calendar.Event, error) { return events, nil } -func displayParsedSchedules(profile, command string, events []*calendar.Event) { +func displayParsedSchedules(terminal *term.Terminal, profile, command string, events []*calendar.Event) { now := time.Now().Round(time.Second) for index, event := range events { - displayHeader(profile, command, index+1, len(events)) + displayHeader(terminal, profile, command, index+1, len(events)) next := event.Next(now) - term.Printf(" Original form: %s\n", event.Input()) - term.Printf("Normalized form: %s\n", event.String()) + terminal.Printf(" Original form: %s\n", event.Input()) + terminal.Printf("Normalized form: %s\n", event.String()) if next.IsZero() { - term.Print(" Next elapse: never\n") + terminal.Print(" Next elapse: never\n") continue } - term.Printf(" Next elapse: %s\n", next.Format(time.UnixDate)) - term.Printf(" (in UTC): %s\n", next.UTC().Format(time.UnixDate)) - term.Printf(" From now: %s left\n", next.Sub(now)) + terminal.Printf(" Next elapse: %s\n", next.Format(time.UnixDate)) + terminal.Printf(" (in UTC): %s\n", next.UTC().Format(time.UnixDate)) + terminal.Printf(" From now: %s left\n", next.Sub(now)) } - term.Print(platform.LineSeparator) + terminal.Print(platform.LineSeparator) } diff --git a/schedule/schedules_test.go b/schedule/schedules_test.go index 1c014fef..6a7357c9 100644 --- a/schedule/schedules_test.go +++ b/schedule/schedules_test.go @@ -2,7 +2,6 @@ package schedule import ( "bytes" - "os" "testing" "github.com/creativeprojects/resticprofile/term" @@ -37,11 +36,10 @@ func TestDisplayParseSchedules(t *testing.T) { events, err := parseSchedules([]string{"daily"}) require.NoError(t, err) - buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) - displayParsedSchedules("profile", "command", events) + displayParsedSchedules(terminal, "profile", "command", events) output := buffer.String() assert.Contains(t, output, "Original form: daily\n") assert.Contains(t, output, "Normalized form: *-*-* 00:00:00\n") @@ -51,11 +49,10 @@ func TestDisplayParseSchedulesWillNeverRun(t *testing.T) { events, err := parseSchedules([]string{"2020-01-01"}) require.NoError(t, err) - buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) - displayParsedSchedules("profile", "command", events) + displayParsedSchedules(terminal, "profile", "command", events) output := buffer.String() assert.Contains(t, output, "Next elapse: never\n") } @@ -64,11 +61,10 @@ func TestDisplayParseSchedulesIndexAndTotal(t *testing.T) { events, err := parseSchedules([]string{"daily", "monthly", "yearly"}) require.NoError(t, err) - buffer := &bytes.Buffer{} - term.SetOutput(buffer) - defer term.SetOutput(os.Stdout) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) - displayParsedSchedules("profile", "command", events) + displayParsedSchedules(terminal, "profile", "command", events) output := buffer.String() assert.Contains(t, output, "schedule 1/3") assert.Contains(t, output, "schedule 2/3") diff --git a/serve.go b/serve.go index ef043f44..cc254011 100644 --- a/serve.go +++ b/serve.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "os" "os/signal" @@ -19,7 +18,7 @@ import ( "github.com/creativeprojects/resticprofile/ssh" ) -func serveCommand(w io.Writer, cmdCtx commandContext) error { +func serveCommand(cmdCtx commandContext) error { if len(cmdCtx.flags.resticArgs) < 2 { return fmt.Errorf("missing argument: port") } @@ -78,7 +77,7 @@ func serveProfiles(port string, config *config.Config, quit chan os.Signal) erro return nil } -func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { +func sendProfileCommand(cmdCtx commandContext) error { if len(cmdCtx.flags.resticArgs) < 2 { return fmt.Errorf("missing argument: remote name") } diff --git a/shell/util.go b/shell/util.go new file mode 100644 index 00000000..700b6a05 --- /dev/null +++ b/shell/util.go @@ -0,0 +1,45 @@ +package shell + +import ( + "bufio" + "bytes" + "io" + + "github.com/creativeprojects/resticprofile/platform" +) + +var ( + bogusPrefix = []byte("\r\x1b[2K") +) + +func LineOutputFilter(output io.Writer, included func(line []byte) bool) io.WriteCloser { + eol := []byte(platform.LineSeparator) + + reader, writer := io.Pipe() + + go func() { + var err error + defer func() { + _ = reader.CloseWithError(err) + }() + + scanner := bufio.NewScanner(reader) + for err == nil && scanner.Scan() { + line := bytes.TrimPrefix(scanner.Bytes(), bogusPrefix) + if !included(line) { + continue + } + if err == nil { + _, err = output.Write(line) + } + if err == nil { + _, err = output.Write(eol) + } + } + if err == nil { + err = scanner.Err() + } + }() + + return writer +} diff --git a/term/default.go b/term/default.go new file mode 100644 index 00000000..340b0e07 --- /dev/null +++ b/term/default.go @@ -0,0 +1,17 @@ +package term + +var defaultTerminal *Terminal + +// Get returns the default terminal. It will be initialized on first use with NewTerminal() if not set before. +func Get() *Terminal { + if defaultTerminal == nil { + defaultTerminal = NewTerminal() + } + return defaultTerminal +} + +// Set stores the default terminal, and returns the terminal reference for chaining. +func Set(t *Terminal) *Terminal { + defaultTerminal = t + return defaultTerminal +} diff --git a/term/nil_reader.go b/term/nil_reader.go new file mode 100644 index 00000000..bfea236a --- /dev/null +++ b/term/nil_reader.go @@ -0,0 +1,7 @@ +package term + +type nilReader struct{} + +func (nilReader) Read(p []byte) (int, error) { + return 0, nil +} diff --git a/term/recorder.go b/term/recorder.go new file mode 100644 index 00000000..756bf711 --- /dev/null +++ b/term/recorder.go @@ -0,0 +1,39 @@ +package term + +import ( + "errors" + "fmt" + "io" + "os" + "sync" +) + +type Recorder struct { + inputWriter *os.File + inputReader *os.File + wg sync.WaitGroup +} + +func NewRecorder(output io.Writer) (*Recorder, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, fmt.Errorf("cannot create OS pipe: %w", err) + } + recorder := &Recorder{ + inputReader: r, + inputWriter: w, + } + recorder.wg.Go(func() { + _, _ = io.Copy(output, recorder.inputReader) + }) + return recorder, nil +} + +func (recorder *Recorder) Close() error { + var err error + err = errors.Join(err, recorder.inputWriter.Close()) + recorder.wg.Wait() + // now that we finished reading, we can close the reader side + err = errors.Join(err, recorder.inputReader.Close()) + return err +} diff --git a/term/recording.go b/term/recording.go new file mode 100644 index 00000000..5fe6e86a --- /dev/null +++ b/term/recording.go @@ -0,0 +1,75 @@ +package term + +import ( + "bytes" + "io" + "sync" + + "github.com/creativeprojects/resticprofile/util" +) + +type outputRecording struct { + lock sync.Mutex + buffer *bytes.Buffer + writer io.Writer + output, error io.Writer +} + +type RecordMode uint8 + +const ( + RecordOutput RecordMode = iota + RecordError + RecordBoth +) + +func (r *outputRecording) StartRecording(mode RecordMode) { + r.lock.Lock() + defer r.lock.Unlock() + if r.buffer != nil { + return + } + + r.buffer = new(bytes.Buffer) + r.writer = util.NewSyncWriterMutex(r.buffer, &r.lock) + + if mode != RecordError { + r.output = GetOutput() + setOutput(r.writer) + } + if mode != RecordOutput { + r.error = GetErrorOutput() + setErrorOutput(r.writer) + } +} + +func (r *outputRecording) ReadRecording() (content string) { + r.lock.Lock() + defer r.lock.Unlock() + if r.buffer != nil { + content = r.buffer.String() + r.buffer.Reset() + } + return +} + +func (r *outputRecording) StopRecording() (content string) { + r.lock.Lock() + defer r.lock.Unlock() + if r.buffer != nil { + if r.output != nil && r.writer == GetOutput() { + setOutput(r.output) + r.output = nil + } + + if r.error != nil && r.writer == GetErrorOutput() { + setErrorOutput(r.error) + r.error = nil + } + + content = r.buffer.String() + r.writer = nil + r.buffer = nil + } + return +} diff --git a/term/term.go b/term/term.go index 5570c5fa..04df3e84 100644 --- a/term/term.go +++ b/term/term.go @@ -2,65 +2,164 @@ package term import ( "bufio" + "bytes" "fmt" "io" "os" "strings" - "sync" + "sync/atomic" + "time" - colorable "github.com/mattn/go-colorable" + "github.com/creativeprojects/resticprofile/util" + "github.com/creativeprojects/resticprofile/util/ansi" + "github.com/mattn/go-colorable" "golang.org/x/term" ) var ( - terminalOutput io.Writer = os.Stdout - errorOutput io.Writer = os.Stderr - PrintToError = false + termOutput atomic.Pointer[io.Writer] + errorOutput atomic.Pointer[io.Writer] + colorOutput atomic.Pointer[io.Writer] + enableColors atomic.Bool + statusChannel = make(chan []string) + statusWaitChannel = make(chan chan bool) + PrintToError = false ) -// Flusher allows a Writer to declare it may buffer content that can be flushed -type Flusher interface { - // Flush writes any pending bytes to output - Flush() error +const ( + StatusFPS = 10 +) + +func init() { + enableColors.Store(true) + go handleStatus() + // must be last + { + setOutput(os.Stdout) + setErrorOutput(os.Stderr) + } } -// AskYesNo prompts the user for a message asking for a yes/no answer -func AskYesNo(reader io.Reader, message string, defaultAnswer bool) bool { - if !strings.HasSuffix(message, "?") { - message += "?" +func handleStatus() { + ticker := time.NewTicker(time.Second / StatusFPS) + defer ticker.Stop() + + var waiting []chan bool + respondWaiting := func(result bool) { + for _, request := range waiting { + request <- result + close(request) + } + waiting = nil } - var question, input string - if defaultAnswer { - question = "(Y/n)" - input = "y" - } else { - question = "(y/N)" - input = "n" - } - fmt.Printf("%s %s: ", message, question) - scanner := bufio.NewScanner(reader) - if scanner.Scan() { - input = strings.TrimSpace(strings.ToLower(scanner.Text())) - if len(input) > 1 { - // take only the first character - input = input[:1] + defer respondWaiting(false) + + var newStatus, status []string + buffer := &bytes.Buffer{} + for { + select { + case lines := <-statusChannel: + newStatus = lines + + case request := <-statusWaitChannel: + waiting = append(waiting, request) + + case <-ticker.C: + if status != nil && OutputIsTerminal() { + width, height := OsStdoutTerminalSize() + noAnsi := !IsColorableOutput() + if height < 1 { + continue + } else if noAnsi { + newStatus = newStatus[1:] // strip first empty line + height = 1 + } + if width >= 60 { + width -= 2 + } else if width >= 80 { + width -= 4 // right margin + } + + last := truncate(status, height) + printable := truncate(newStatus, height) + removedLines := len(last) - len(printable) + if removedLines > 0 { + filler := make([]string, removedLines, removedLines+len(printable)) + printable = append(filler, printable...) + } + + if len(printable) > 0 { + buffer.Reset() + for index, line := range printable { + runes := []rune(strings.ReplaceAll(line, "\n", " ")) + _, maxIndex := ansi.RunesLength(runes, width) + runes = truncate(runes, maxIndex) + + if noAnsi { + if remaining := width - len(runes); remaining > 0 { + for remaining > 0 { + runes = append(runes, ' ') + remaining-- + } + } + _, _ = fmt.Fprintf(buffer, "\r%s\r", string(runes)) + } else { + eol := "\n" + if index+1 == len(printable) { + eol = "\r" + } + _, _ = fmt.Fprintf(buffer, "\r%s%s%s%s", ansi.ClearLine, string(runes), ansi.Reset, eol) + } + } + + if !noAnsi { + buffer.WriteString(ansi.CursorUpLeftN(len(printable) - 1)) + } + + _, _ = buffer.WriteTo(getColorableOutput()) + buffer.Reset() + } + } + status = newStatus + respondWaiting(true) } } +} - if input == "" { - return defaultAnswer +func truncate[E any](src []E, maxLength int) []E { + if len(src) > maxLength { + return src[:maxLength] } - if input == "y" { - return true + return src +} + +// SetStatus sets a status line(s) that is printed when the output is an interactive terminal +// +// Deprecated: use term.Terminal instead +func SetStatus(line []string) { + // Clone lines and add empty line on top (= cursor position after printing status) + if line != nil { + line = append([]string{""}, line...) } - return false + statusChannel <- line +} + +// WaitForStatus blocks until the previously provided status was applied +// +// Deprecated: use term.Terminal instead +func WaitForStatus() bool { + request := make(chan bool, 1) + statusWaitChannel <- request + return <-request } // ReadPassword reads a password without echoing it to the terminal. +// +// Deprecated: use term.Terminal instead func ReadPassword() (string, error) { stdin := fdToInt(os.Stdin.Fd()) if !term.IsTerminal(stdin) { - return ReadLine() + return readLine() } line, err := term.ReadPassword(stdin) _, _ = fmt.Fprintln(os.Stderr) @@ -70,8 +169,8 @@ func ReadPassword() (string, error) { return string(line), nil } -// ReadLine reads some input -func ReadLine() (string, error) { +// readLine reads some input +func readLine() (string, error) { buf := bufio.NewReader(os.Stdin) line, err := buf.ReadString('\n') if err != nil { @@ -81,12 +180,15 @@ func ReadLine() (string, error) { } // OsStdoutIsTerminal returns true as os.Stdout is a terminal session +// +// Deprecated: use term.Terminal instead func OsStdoutIsTerminal() bool { - fd := fdToInt(os.Stdout.Fd()) - return term.IsTerminal(fd) + return isTerminal(os.Stdout) } -// OsStdoutTerminalSize returns the current width and height of os.Stdout +// OsStdoutTerminalSize returns the width and height of the terminal session +// +// Deprecated: use term.Terminal instead func OsStdoutTerminalSize() (width, height int) { fd := fdToInt(os.Stdout.Fd()) var err error @@ -97,98 +199,170 @@ func OsStdoutTerminalSize() (width, height int) { return } -func fdToInt(fd uintptr) int { - return int(fd) //nolint:gosec +// OutputIsTerminal returns true if GetOutput sends to an interactive terminal +// +// Deprecated: use term.Terminal instead +func OutputIsTerminal() bool { + return GetOutput() == os.Stdout && OsStdoutIsTerminal() } -type LockedWriter struct { - writer io.Writer - mutex *sync.Mutex +// SetOutput changes the default output for the Print* functions +// +// Deprecated: use term.Terminal instead +func SetOutput(w io.Writer) { + if w == os.Stdout && isTerminal(os.Stdout) { + setOutput(os.Stdout) + } else { + setOutput(util.NewSyncWriter(w)) + } } -func (w *LockedWriter) Write(p []byte) (n int, err error) { - w.mutex.Lock() - defer w.mutex.Unlock() - return w.writer.Write(p) +func setOutput(w io.Writer) { + if w == nil { + w = io.Discard + } + termOutput.Store(&w) + colorOutput.Store(nil) + SetStatus(nil) } -func (w *LockedWriter) Flush() (err error) { - w.mutex.Lock() - defer w.mutex.Unlock() - if f, ok := w.writer.(Flusher); ok { - err = f.Flush() +// GetOutput returns the default output of the Print* functions +// +// Deprecated: use term.Terminal instead +func GetOutput() (out io.Writer) { + if v := termOutput.Load(); v != nil { + out = *v } return } -// SetOutput changes the default output for the Print* functions -func SetOutput(w io.Writer) { - terminalOutput = &LockedWriter{writer: w, mutex: new(sync.Mutex)} +// getColorableOutput returns an output supporting ANSI color if output is a terminal +func getColorableOutput() (out io.Writer) { + if v := colorOutput.Load(); v != nil { + out = *v + } + if out == nil { + if IsColorableOutput() { + out = colorable.NewColorable(os.Stdout) + } else { + out = colorable.NewNonColorable(outputWriter()) + } + colorOutput.Store(&out) + } + return out } -// GetOutput returns the default output of the Print* functions -func GetOutput() io.Writer { - return terminalOutput +// IsColorableOutput tells whether GetColorableOutput supports ANSI color (and control characters) or discards ANSI +// +// Deprecated: use term.Terminal instead +func IsColorableOutput() bool { + return enableColors.Load() && OutputIsTerminal() } -// GetColorableOutput returns an output supporting ANSI color if output is a terminal -func GetColorableOutput() io.Writer { - out := GetOutput() - if out == os.Stdout && OsStdoutIsTerminal() { - return colorable.NewColorable(os.Stdout) +// SetErrorOutput changes the error output for the Print* functions +// +// Deprecated: use term.Terminal instead +func SetErrorOutput(w io.Writer) { + if w == os.Stderr && isTerminal(os.Stderr) { + setErrorOutput(os.Stderr) + } else { + setErrorOutput(util.NewSyncWriter(w)) } - return colorable.NewNonColorable(out) } -// SetErrorOutput changes the error output for the Print* functions -func SetErrorOutput(w io.Writer) { - errorOutput = &LockedWriter{writer: w, mutex: new(sync.Mutex)} +func setErrorOutput(w io.Writer) { + if w == nil { + w = io.Discard + } + errorOutput.Store(&w) } // GetErrorOutput returns the error output of the Print* functions -func GetErrorOutput() io.Writer { - return errorOutput +// +// Deprecated: use term.Terminal instead +func GetErrorOutput() (out io.Writer) { + if v := errorOutput.Load(); v != nil { + out = *v + } + return } // SetAllOutput changes the default and error output for the Print* functions +// +// Deprecated: use term.Terminal instead func SetAllOutput(w io.Writer) { - single := new(sync.Mutex) - terminalOutput = &LockedWriter{writer: w, mutex: single} - errorOutput = &LockedWriter{writer: w, mutex: single} + SetOutput(w) + setErrorOutput(GetOutput()) } // FlushAllOutput flushes all buffered output (if supported by the underlying Writer). +// +// Deprecated: use term.Terminal instead func FlushAllOutput() { - for _, writer := range []io.Writer{terminalOutput, errorOutput} { - if f, ok := writer.(Flusher); ok { - _ = f.Flush() - } + for _, writer := range []io.Writer{GetOutput(), GetErrorOutput()} { + _, _ = util.FlushWriter(writer) } } -// Print formats using the default formats for its operands and writes to standard output. -// Spaces are added between operands when neither is a string. -// It returns the number of bytes written and any write error encountered. -func Print(a ...any) (n int, err error) { - return fmt.Fprint(outputWriter(), a...) +var recording = &outputRecording{} + +// Deprecated: use term.Terminal instead +func StartRecording(mode RecordMode) { + recording.lock.Lock() + defer recording.lock.Unlock() + if recording.buffer != nil { + return + } + + recording.buffer = new(bytes.Buffer) + recording.writer = util.NewSyncWriterMutex(recording.buffer, &recording.lock) + + if mode != RecordError { + recording.output = GetOutput() + setOutput(recording.writer) + } + if mode != RecordOutput { + recording.error = GetErrorOutput() + setErrorOutput(recording.writer) + } } -// Println formats using the default formats for its operands and writes to standard output. -// Spaces are always added between operands and a newline is appended. -// It returns the number of bytes written and any write error encountered. -func Println(a ...any) (n int, err error) { - return fmt.Fprintln(outputWriter(), a...) +// Deprecated: use term.Terminal instead +func ReadRecording() (content string) { + recording.lock.Lock() + defer recording.lock.Unlock() + if recording.buffer != nil { + content = recording.buffer.String() + recording.buffer.Reset() + } + return } -// Printf formats according to a format specifier and writes to standard output. -// It returns the number of bytes written and any write error encountered. -func Printf(format string, a ...any) (n int, err error) { - return fmt.Fprintf(outputWriter(), format, a...) +// Deprecated: use term.Terminal instead +func StopRecording() (content string) { + recording.lock.Lock() + defer recording.lock.Unlock() + if recording.buffer != nil { + if recording.output != nil && recording.writer == GetOutput() { + setOutput(recording.output) + recording.output = nil + } + + if recording.error != nil && recording.writer == GetErrorOutput() { + setErrorOutput(recording.error) + recording.error = nil + } + + content = recording.buffer.String() + recording.writer = nil + recording.buffer = nil + } + return } func outputWriter() io.Writer { if PrintToError { - return errorOutput + return GetErrorOutput() } - return terminalOutput + return GetOutput() } diff --git a/term/term_test.go b/term/term_test.go deleted file mode 100644 index 9215747c..00000000 --- a/term/term_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package term - -import ( - "bytes" - "log" - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -type askYesNoTestData struct { - input string - defaultAnswer bool - expected bool -} - -func TestAskYesNo(t *testing.T) { - testData := []askYesNoTestData{ - // Empty answer => will follow the defaultAnswer - {"", true, true}, - {"", false, false}, - {"\n", true, true}, - {"\n", false, false}, - {"\r\n", true, true}, - {"\r\n", false, false}, - // Garbage answer => will always return false - {"aa", true, false}, - {"aa", false, false}, - {"aa\n", true, false}, - {"aa\n", false, false}, - {"aa\r\n", true, false}, - {"aa\r\n", false, false}, - // Answer yes - {"y", true, true}, - {"y", false, true}, - {"y\n", true, true}, - {"y\n", false, true}, - {"y\r\n", true, true}, - {"y\r\n", false, true}, - // Full answer yes - {"yes", true, true}, - {"yes", false, true}, - {"yes\n", true, true}, - {"yes\n", false, true}, - {"yes\r\n", true, true}, - {"yes\r\n", false, true}, - // Answer no - {"n", true, false}, - {"n", false, false}, - {"n\n", true, false}, - {"n\n", false, false}, - {"n\r\n", true, false}, - {"n\r\n", false, false}, - // Full answer no - {"no", true, false}, - {"no", false, false}, - {"no\n", true, false}, - {"no\n", false, false}, - {"no\r\n", true, false}, - {"no\r\n", false, false}, - } - for _, testItem := range testData { - result := AskYesNo( - bytes.NewBufferString(testItem.input), - "message", - testItem.defaultAnswer, - ) - assert.Equalf(t, testItem.expected, result, "when input was %q", testItem.input) - } -} - -func ExamplePrint() { - PrintToError = false - SetOutput(os.Stdout) - _, err := Print("ExamplePrint") - if err != nil { - log.Fatal(err) - } - // Output: ExamplePrint -} - -func TestCanRedirectTermOutput(t *testing.T) { - PrintToError = false - message := "TestCanRedirectTermOutput" - outputBuffer := &bytes.Buffer{} - errorBuffer := &bytes.Buffer{} - SetOutput(outputBuffer) - SetErrorOutput(errorBuffer) - _, err := Print(message) - assert.NoError(t, err) - assert.Equal(t, message, outputBuffer.String()) - assert.Empty(t, errorBuffer.String()) -} - -func TestCanRedirectTermErrorOutput(t *testing.T) { - PrintToError = true - message := "TestCanRedirectTermOutput" - outputBuffer := &bytes.Buffer{} - errorBuffer := &bytes.Buffer{} - SetOutput(outputBuffer) - SetErrorOutput(errorBuffer) - _, err := Print(message) - assert.NoError(t, err) - assert.Equal(t, message, errorBuffer.String()) - assert.Empty(t, outputBuffer.String()) -} diff --git a/term/terminal.go b/term/terminal.go new file mode 100644 index 00000000..a14834e1 --- /dev/null +++ b/term/terminal.go @@ -0,0 +1,193 @@ +package term + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" + + "github.com/creativeprojects/resticprofile/util" + "github.com/creativeprojects/resticprofile/util/maybe" + "github.com/mattn/go-colorable" + "golang.org/x/term" +) + +type Terminal struct { + stdin io.Reader + stdout io.Writer // final writer + stderr io.Writer // final writer + enableColors maybe.Bool + colorableStdout io.Writer // colorable writer + colorableStderr io.Writer // colorable writer +} + +func NewTerminal(options ...terminalOption) *Terminal { + t := &Terminal{ + stdin: os.Stdin, + stdout: os.Stdout, + stderr: os.Stderr, + } + + for _, option := range options { + option(t) + } + + t.colorableStdout = t.getColorableWriter(t.stdout) + t.colorableStderr = t.getColorableWriter(t.stderr) + + return t +} + +// AskYesNo prompts the user for a message asking for a yes/no answer +func (t *Terminal) AskYesNo(message string, defaultAnswer bool) bool { + if !strings.HasSuffix(message, "?") { + message += "?" + } + var question, input string + if defaultAnswer { + question = "(Y/n)" + input = "y" + } else { + question = "(y/N)" + input = "n" + } + _, _ = t.Printf("%s %s: ", message, question) + scanner := bufio.NewScanner(t.stdin) + if scanner.Scan() { + input = strings.TrimSpace(strings.ToLower(scanner.Text())) + if len(input) > 1 { + // take only the first character + input = input[:1] + } + } + + if input == "" { + return defaultAnswer + } + if input == "y" { + return true + } + return false +} + +// ReadPassword reads a password without echoing it to the terminal. +func (t *Terminal) ReadPassword() (string, error) { + stdin, ok := t.stdin.(*os.File) + if !ok || !isTerminal(stdin) { + return t.readLine() + } + line, err := term.ReadPassword(fdToInt(stdin.Fd())) + if err != nil { + return "", fmt.Errorf("failed to read password: %w", err) + } + return string(line), nil +} + +// ReadLine reads some input +func (t *Terminal) readLine() (string, error) { + buf := bufio.NewReader(t.stdin) + line, err := buf.ReadString('\n') + if err != nil { + return "", fmt.Errorf("failed to read line: %w", err) + } + return strings.TrimSpace(line), nil +} + +// StdoutIsTerminal returns true as stdout is a terminal session +func (t *Terminal) StdoutIsTerminal() bool { + return isTerminalWriter(t.stdout) +} + +// StderrIsTerminal returns true as stderr is a terminal session +func (t *Terminal) StderrIsTerminal() bool { + return isTerminalWriter(t.stderr) +} + +func (t *Terminal) getColorableWriter(w io.Writer) io.Writer { + if file, ok := w.(*os.File); ok && t.enableColors.IsTrueOrUndefined() && (isTerminal(file) || t.enableColors.IsTrue()) { + return colorable.NewColorable(file) + } + if t.enableColors.IsTrue() { + // output is not a file, but the user doesn't want to strip the colors + return w + } + return colorable.NewNonColorable(w) +} + +// Size returns the width and height of the terminal session +func (t *Terminal) Size() (width, height int) { + fd := fdToInt(os.Stdout.Fd()) + var err error + width, height, err = term.GetSize(fd) + if err != nil { + width, height = 0, 0 + } + return +} + +// FlushAllOutput flushes all buffered output (if supported by the underlying Writer). +func (t *Terminal) FlushAllOutput() { + for _, writer := range []io.Writer{t.colorableStdout, t.colorableStderr, t.stdout, t.stderr} { + _, _ = util.FlushWriter(writer) + } +} + +// Print formats using the default formats for its operands and writes to standard output. +// Spaces are added between operands when neither is a string. +// It returns the number of bytes written and any write error encountered. +func (t *Terminal) Print(a ...any) (n int, err error) { + return fmt.Fprint(t.colorableStdout, a...) +} + +// Println formats using the default formats for its operands and writes to standard output. +// Spaces are always added between operands and a newline is appended. +// It returns the number of bytes written and any write error encountered. +func (t *Terminal) Println(a ...any) (n int, err error) { + return fmt.Fprintln(t.colorableStdout, a...) +} + +// Printf formats according to a format specifier and writes to standard output. +// It returns the number of bytes written and any write error encountered. +func (t *Terminal) Printf(format string, a ...any) (n int, err error) { + return fmt.Fprintf(t.colorableStdout, format, a...) +} + +func (t *Terminal) Scanln(a ...any) (n int, err error) { + return fmt.Fscanln(t.stdin, a...) +} + +// Write implements the io.Writer interface, writing to the terminal's stdout. +func (t *Terminal) Write(p []byte) (n int, err error) { + return t.colorableStdout.Write(p) +} + +func (t *Terminal) Stdout() io.Writer { + return t.colorableStdout +} + +func (t *Terminal) Stderr() io.Writer { + return t.colorableStderr +} + +func isTerminalWriter(w io.Writer) bool { + file, ok := w.(*os.File) + if !ok { + return false + } + return isTerminal(file) +} + +func isTerminal(file *os.File) bool { + if file == nil { + return false + } + fd := fdToInt(file.Fd()) + return term.IsTerminal(fd) +} + +func fdToInt(fd uintptr) int { + return int(fd) //nolint:gosec +} + +var _ io.Writer = (*Terminal)(nil) diff --git a/term/terminal_option.go b/term/terminal_option.go new file mode 100644 index 00000000..81dd676c --- /dev/null +++ b/term/terminal_option.go @@ -0,0 +1,62 @@ +package term + +import ( + "io" + "os" + + "github.com/creativeprojects/resticprofile/util" + "github.com/creativeprojects/resticprofile/util/maybe" +) + +type terminalOption func(t *Terminal) + +func WithStdin(stdin io.Reader) terminalOption { + if stdin == nil { + stdin = nilReader{} + } + return func(t *Terminal) { + t.stdin = stdin + } +} + +func WithStdout(stdout io.Writer) terminalOption { + if stdout == nil { + stdout = io.Discard + } + if stdout != os.Stdout && stdout != os.Stderr { + stdout = util.NewSyncWriter(stdout) + } + return func(t *Terminal) { + t.stdout = stdout + } +} + +func WithStderr(stderr io.Writer) terminalOption { + if stderr == nil { + stderr = io.Discard + } + if stderr != os.Stdout && stderr != os.Stderr { + stderr = util.NewSyncWriter(stderr) + } + return func(t *Terminal) { + t.stderr = stderr + } +} + +func WithColors(enable bool) terminalOption { + return func(t *Terminal) { + t.enableColors = maybe.SetBool(enable) + } +} + +func WithStdoutRecorder(recorder *Recorder) terminalOption { + return func(t *Terminal) { + t.stdout = recorder.inputWriter + } +} + +func WithStderrRecorder(recorder *Recorder) terminalOption { + return func(t *Terminal) { + t.stderr = recorder.inputWriter + } +} diff --git a/term/terminal_test.go b/term/terminal_test.go new file mode 100644 index 00000000..7ded96a5 --- /dev/null +++ b/term/terminal_test.go @@ -0,0 +1,190 @@ +package term + +import ( + "bytes" + "io" + "log" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTerminalAskYesNo(t *testing.T) { + t.Parallel() + + fixtures := []struct { + input string + defaultAnswer bool + expected bool + }{ + // Empty answer => will follow the defaultAnswer + {"", true, true}, + {"", false, false}, + {"\n", true, true}, + {"\n", false, false}, + {"\r\n", true, true}, + {"\r\n", false, false}, + // Garbage answer => will always return false + {"aa", true, false}, + {"aa", false, false}, + {"aa\n", true, false}, + {"aa\n", false, false}, + {"aa\r\n", true, false}, + {"aa\r\n", false, false}, + // Answer yes + {"y", true, true}, + {"y", false, true}, + {"y\n", true, true}, + {"y\n", false, true}, + {"y\r\n", true, true}, + {"y\r\n", false, true}, + // Full answer yes + {"yes", true, true}, + {"yes", false, true}, + {"yes\n", true, true}, + {"yes\n", false, true}, + {"yes\r\n", true, true}, + {"yes\r\n", false, true}, + // Answer no + {"n", true, false}, + {"n", false, false}, + {"n\n", true, false}, + {"n\n", false, false}, + {"n\r\n", true, false}, + {"n\r\n", false, false}, + // Full answer no + {"no", true, false}, + {"no", false, false}, + {"no\n", true, false}, + {"no\n", false, false}, + {"no\r\n", true, false}, + {"no\r\n", false, false}, + } + for _, fixture := range fixtures { + output := new(bytes.Buffer) + terminal := NewTerminal(WithStdin(bytes.NewBufferString(fixture.input)), WithStdout(output)) + result := terminal.AskYesNo("message", fixture.defaultAnswer) + assert.Contains(t, output.String(), "message? (") + assert.Equalf(t, fixture.expected, result, "when input was %q", fixture.input) + } +} + +func ExamplePrint() { + terminal := NewTerminal() + _, err := terminal.Print("ExampleTerminalPrint") + if err != nil { + log.Fatal(err) + } + // Output: ExampleTerminalPrint +} + +func TestDefaultTerminal(t *testing.T) { + terminal := NewTerminal() + // in a test environment, stdout and stderr are not terminals, so we expect false for both + assert.False(t, terminal.StdoutIsTerminal()) + assert.False(t, terminal.StderrIsTerminal()) +} + +func TestReadPasswordFromBuffer(t *testing.T) { + input := "mysecretpassword\n" + terminal := NewTerminal(WithStdin(bytes.NewBufferString(input))) + password, err := terminal.ReadPassword() + assert.NoError(t, err) + assert.Equal(t, "mysecretpassword", password) +} + +func TestTerminalOutputCapture(t *testing.T) { + buffer := new(bytes.Buffer) + recorder, err := NewRecorder(buffer) + require.NoError(t, err) + + terminal := NewTerminal(WithStdoutRecorder(recorder)) + written, err := terminal.Print("TestTerminalOutputCapture") + require.NoError(t, err) + assert.Equal(t, 25, written) + + err = recorder.Close() + require.NoError(t, err) + assert.Equal(t, "TestTerminalOutputCapture", buffer.String()) +} + +// ansiText is a string with ANSI color codes that should be passed through or stripped depending on the writer. +const ansiText = "\x1b[31mhello\x1b[0m" +const plainText = "hello" + +func TestGetColorableWriterWithNonFileWriterReturnsNonColorable(t *testing.T) { + t.Parallel() + // A bytes.Buffer is not an *os.File, so getColorableWriter always wraps it as NonColorable. + buf := &bytes.Buffer{} + terminal := NewTerminal() + writer := terminal.getColorableWriter(buf) + + _, err := writer.Write([]byte(ansiText)) + require.NoError(t, err) + + // NonColorable strips ANSI escape codes. + assert.Equal(t, plainText, buf.String()) +} + +func TestGetColorableWriterWithColorsDisabledReturnsNonColorable(t *testing.T) { + t.Parallel() + // Even for an *os.File, explicitly disabling colors forces a NonColorable writer. + pr, pw, err := os.Pipe() + require.NoError(t, err) + defer pr.Close() + + terminal := NewTerminal(WithColors(false)) + writer := terminal.getColorableWriter(pw) + + _, err = writer.Write([]byte(ansiText)) + require.NoError(t, err) + pw.Close() + + out, err := io.ReadAll(pr) + require.NoError(t, err) + // NonColorable strips ANSI escape codes. + assert.Equal(t, plainText, string(out)) +} + +func TestGetColorableWriterWithColorsUndefinedOnNonTerminalFileReturnsNonColorable(t *testing.T) { + t.Parallel() + // With undefined colors and a non-terminal *os.File (e.g. a pipe), the writer is NonColorable. + pr, pw, err := os.Pipe() + require.NoError(t, err) + defer pr.Close() + + terminal := NewTerminal() // enableColors is undefined + writer := terminal.getColorableWriter(pw) + + _, err = writer.Write([]byte(ansiText)) + require.NoError(t, err) + pw.Close() + + out, err := io.ReadAll(pr) + require.NoError(t, err) + // NonColorable strips ANSI escape codes. + assert.Equal(t, plainText, string(out)) +} + +func TestGetColorableWriterWithColorsEnabledOnNonTerminalFileReturnsColorable(t *testing.T) { + t.Parallel() + // With colors explicitly enabled, getColorableWriter returns a Colorable writer even for a + // non-terminal *os.File, so ANSI codes are passed through unchanged. + pr, pw, err := os.Pipe() + require.NoError(t, err) + defer pr.Close() + + terminal := NewTerminal(WithColors(true)) + writer := terminal.getColorableWriter(pw) + + _, err = writer.Write([]byte(ansiText)) + require.NoError(t, err) + pw.Close() + + out, err := io.ReadAll(pr) + require.NoError(t, err) + // Colorable preserves ANSI escape codes. + assert.Equal(t, ansiText, string(out)) +} diff --git a/term/test/main.go b/term/test/main.go new file mode 100644 index 00000000..20d3028a --- /dev/null +++ b/term/test/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "github.com/creativeprojects/resticprofile/term" + "github.com/creativeprojects/resticprofile/util/ansi" +) + +func main() { + terminal := term.NewTerminal() + terminal.Println(ansi.Bold("colorable terminal")) + terminal.Printf("stdout is terminal: %v, stderr is terminal: %v\n", terminal.StdoutIsTerminal(), terminal.StderrIsTerminal()) + + terminal = term.NewTerminal(term.WithColors(false)) + terminal.Println(ansi.Bold("non colorable terminal")) + terminal.Printf("stdout is terminal: %v, stderr is terminal: %v\n", terminal.StdoutIsTerminal(), terminal.StderrIsTerminal()) +} diff --git a/update.go b/update.go index ed89a0c6..2226fab5 100644 --- a/update.go +++ b/update.go @@ -6,7 +6,6 @@ import ( "context" "errors" "fmt" - "io" "os" "runtime" @@ -32,19 +31,19 @@ func init() { config.ExcludeProfileSection(def.name) } -func selfUpdate(_ io.Writer, ctx commandContext) error { +func selfUpdate(ctx commandContext) error { quiet := ctx.flags.quiet if !quiet && len(ctx.request.arguments) > 0 && (ctx.request.arguments[0] == "-q" || ctx.request.arguments[0] == "--quiet") { quiet = true } - err := confirmAndSelfUpdate(quiet, ctx.flags.verbose, version, true) + err := confirmAndSelfUpdate(ctx.terminal, quiet, ctx.flags.verbose, version, true) if err != nil { return err } return nil } -func confirmAndSelfUpdate(quiet, debug bool, version string, prerelease bool) error { +func confirmAndSelfUpdate(terminal *term.Terminal, quiet, debug bool, version string, prerelease bool) error { if debug { selfupdate.SetLogger(clog.NewStandardLogger(clog.LevelDebug, clog.GetDefaultLogger())) } @@ -70,8 +69,8 @@ func confirmAndSelfUpdate(quiet, debug bool, version string, prerelease bool) er } // don't ask in quiet mode - if !quiet && !term.AskYesNo(os.Stdin, fmt.Sprintf("Do you want to update to version %s", latest.Version()), true) { - term.Println("Never mind") + if !quiet && !terminal.AskYesNo(fmt.Sprintf("Do you want to update to version %s", latest.Version()), true) { + terminal.Println("Never mind") return nil } diff --git a/update_test.go b/update_test.go index ddf21e93..a17d097c 100644 --- a/update_test.go +++ b/update_test.go @@ -7,6 +7,7 @@ import ( "github.com/creativeprojects/clog" "github.com/creativeprojects/go-selfupdate" + "github.com/creativeprojects/resticprofile/term" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,7 +20,7 @@ func TestUpdate(t *testing.T) { clog.SetTestLog(t) defer clog.CloseTestLog() - err := confirmAndSelfUpdate(true, true, "0.0.1", false) + err := confirmAndSelfUpdate(term.NewTerminal(), true, true, "0.0.1", false) require.ErrorIsf(t, err, selfupdate.ErrExecutableNotFoundInArchive, "error returned isn't wrapping %q but is instead: %q", selfupdate.ErrExecutableNotFoundInArchive, err) assert.Contains(t, err.Error(), "resticprofile.test") } diff --git a/util/ansi/colors.go b/util/ansi/colors.go new file mode 100644 index 00000000..748afcc6 --- /dev/null +++ b/util/ansi/colors.go @@ -0,0 +1,39 @@ +package ansi + +import ( + "strings" + + "github.com/fatih/color" +) + +var ( + bold = color.New(color.Bold) + Bold = bold.SprintFunc() + cyan = color.New(color.FgCyan) + Cyan = cyan.SprintFunc() + gray = New256FgColor(243) + Gray = gray.SprintFunc() + green = color.New(color.FgGreen) + Green = green.SprintFunc() + yellow = color.New(color.FgYellow) + Yellow = yellow.SprintFunc() + underline = color.New(color.Underline) + Underline = underline.Sprint() +) + +// New256FgColor return a new xterm 256 (8bit) foreground color +func New256FgColor(code uint8) *color.Color { + return color.New(38, 5, color.Attribute(code)) +} + +// New256BgColor return a new xterm 256 (8bit) background color +func New256BgColor(code uint8) *color.Color { + return color.New(48, 5, color.Attribute(code)) +} + +func ColorSequence(fn func(a ...any) string) (start, stop string) { + s := fn("||") + before, after, _ := strings.Cut(s, "||") + start, stop = before, after + return +} diff --git a/util/ansi/colors_test.go b/util/ansi/colors_test.go new file mode 100644 index 00000000..03513d57 --- /dev/null +++ b/util/ansi/colors_test.go @@ -0,0 +1,19 @@ +package ansi + +import ( + "strings" + "testing" + + "github.com/fatih/color" + "github.com/stretchr/testify/assert" +) + +func Test256Colors(t *testing.T) { + test := func(color *color.Color, expected string) { + color.EnableColor() + seq := strings.Split(color.Sprint("||"), "||")[0] + assert.Equal(t, expected, seq) + } + test(New256FgColor(10), Sequence('m', 38, 5, 10)) + test(New256BgColor(10), Sequence('m', 48, 5, 10)) +} diff --git a/util/ansi/escape.go b/util/ansi/escape.go new file mode 100644 index 00000000..5ca14831 --- /dev/null +++ b/util/ansi/escape.go @@ -0,0 +1,36 @@ +package ansi + +import ( + "fmt" + "strings" +) + +const ( + EscapeByte = 0x1b + Escape = "\x1b" + ClearLine = Escape + "[2K" + Reset = Escape + "[0m" +) + +// Sequence builds an arbitrary escape sequence +func Sequence(terminator byte, attributes ...any) string { + seq := strings.Builder{} + seq.WriteString(Escape) + seq.WriteByte('[') + for i, attribute := range attributes { + if i > 0 { + seq.WriteByte(';') + } + _, _ = fmt.Fprint(&seq, attribute) + } + seq.WriteByte(terminator) + return seq.String() +} + +// CursorUpLeftN creates the escape sequence to move the cursor left and up N lines +func CursorUpLeftN(lines int) string { + if lines < 0 { + lines = 0 + } + return Sequence('F', lines) +} diff --git a/util/ansi/escape_test.go b/util/ansi/escape_test.go new file mode 100644 index 00000000..4b25be8f --- /dev/null +++ b/util/ansi/escape_test.go @@ -0,0 +1,14 @@ +package ansi + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCursorUpLeftN(t *testing.T) { + assert.Equal(t, Escape+"[0F", CursorUpLeftN(-1)) + assert.Equal(t, Escape+"[0F", CursorUpLeftN(0)) + assert.Equal(t, Escape+"[1F", CursorUpLeftN(1)) + assert.Equal(t, Escape+"[2F", CursorUpLeftN(2)) +} diff --git a/util/ansi/linewriter.go b/util/ansi/linewriter.go new file mode 100644 index 00000000..b9ed0297 --- /dev/null +++ b/util/ansi/linewriter.go @@ -0,0 +1,110 @@ +package ansi + +import ( + "io" + "unicode/utf8" +) + +type lineLengthWriter struct { + writer io.Writer + tokens []byte + maxLineLength, lastWhite, breakLength, lineLength int + invisibleLength, lastWhiteInvisibleLength int + inAnsi bool +} + +// NewLineLengthWriter return an io.Writer that limits the max line length, adding line breaks ('\n') as needed. +// The writer detects the right most column (consecutive whitespace) and aligns content if possible. +// UTF sequences are counted as single character and ANSI escape sequences are not counted at all. +func NewLineLengthWriter(writer io.Writer, maxLineLength int) io.Writer { + return &lineLengthWriter{ + tokens: []byte{' ', '\n'}, + writer: writer, + maxLineLength: maxLineLength, + } +} + +func (l *lineLengthWriter) visibleLineLength() int { return l.lineLength - l.invisibleLength } + +func (l *lineLengthWriter) Write(p []byte) (n int, err error) { + var written int + offset := l.lineLength + + for i := 0; i < len(p); i++ { + l.lineLength++ + ws := p[i] == l.tokens[0] // whitespace + br := p[i] == l.tokens[1] // linebreak + + // don't count ansi control sequences + if l.inAnsi = l.inAnsi || p[i] == EscapeByte; l.inAnsi { + terminator := (p[i] >= 'a' && p[i] <= 'z') || (p[i] >= 'A' && p[i] <= 'Z') + l.inAnsi = !terminator + l.invisibleLength++ + continue + } + + // count UTF sequence as one character + if p[i] >= utf8.RuneSelf { + if !utf8.RuneStart(p[i]) { + l.invisibleLength++ + } + continue + } + + if !br && l.visibleLineLength() > l.maxLineLength && l.lastWhite-offset > 0 { + lastWhiteIndex := l.lastWhite - offset - 1 + remainder := i - lastWhiteIndex + + if written, err = l.writer.Write(p[:lastWhiteIndex]); err == nil { + p = p[lastWhiteIndex+1:] + i = remainder - 1 + n += written + 1 + + _, _ = l.writer.Write(l.tokens[1:]) // write break (instead of WS at lastWhiteIndex) + for range l.breakLength { + _, _ = l.writer.Write(l.tokens[0:1]) // fill spaces for alignment + } + + l.lineLength = l.breakLength + remainder + l.lastWhite = l.breakLength + offset = l.breakLength + + l.invisibleLength -= l.lastWhiteInvisibleLength + l.lastWhiteInvisibleLength = 0 + } else { + return + } + } + + if ws { + if l.lastWhite == l.lineLength-1 && l.visibleLineLength() < l.maxLineLength*2/3 { + l.breakLength = l.visibleLineLength() + } + l.lastWhite = l.lineLength + l.lastWhiteInvisibleLength = l.invisibleLength + + } else if br { + if written, err = l.writer.Write(p[:i+1]); err == nil { + p = p[i+1:] + i = -1 + n += written + + l.lineLength = 0 + l.lastWhite = 0 + l.breakLength = 0 + offset = 0 + + l.invisibleLength = 0 + l.lastWhiteInvisibleLength = 0 + } else { + return + } + } + } + + // write remainder + if written, err = l.writer.Write(p); err == nil { + n += written + } + return +} diff --git a/util/ansi/linewriter_test.go b/util/ansi/linewriter_test.go new file mode 100644 index 00000000..62146d4e --- /dev/null +++ b/util/ansi/linewriter_test.go @@ -0,0 +1,280 @@ +package ansi + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/fatih/color" + "github.com/stretchr/testify/assert" +) + +var ansiColor = func() (c *color.Color) { + c = color.New(color.FgCyan) + c.EnableColor() + return +}() + +var colored = ansiColor.SprintFunc() + +func TestLineLengthWriter(t *testing.T) { + tests := []struct { + input, expected string + chunks, scale int + }{ + // test non-breakable + {input: strings.Repeat("-", 50), expected: strings.Repeat("-", 50), chunks: 15}, + + // test breakable without columns + { + input: strings.Repeat("word ", 20), + expected: "" + + strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + + strings.TrimSpace(strings.Repeat("word ", 8)) + "\n" + + strings.Repeat("word ", 4), + chunks: 5, scale: 6, + }, + + // test breakable with ANSI color + { + input: strings.Repeat(colored("word "), 20), + expected: "" + + strings.Repeat(colored("word "), 7) + colored("word\n") + + strings.Repeat(colored("word "), 7) + colored("word\n") + + strings.Repeat(colored("word "), 4), + }, + + // test breakable with 2 columns + { + input: "word word word word " + + strings.Repeat("word ", 20), + expected: "" + + "word word word word " + + "word word word\n" + + strings.Repeat(" word word word\n", 5) + + " word word ", + chunks: 3, scale: 15, + }, + + // test breakable with 2 columns and ANSI color + { + input: colored("word word word word ") + + strings.Repeat(colored("word "), 20), + expected: "" + + colored("word word word word ") + + colored("word ") + colored("word ") + colored("word\n ") + + strings.Repeat(colored("word ")+colored("word ")+colored("word\n "), 5) + + colored("word ") + colored("word "), + }, + + // test breakable with 2 columns and unicode character + { + input: "w😁rd wo😁d wor😁 😁ord 😁😁😁😁 w😁rd wo😁d wor😁 😁ord 😁😁😁😁", + expected: "w😁rd wo😁d wor😁 😁ord 😁😁😁😁 w😁rd wo😁d \n" + + " wor😁 😁ord 😁😁😁😁", + }, + { + input: "word word word word word word word word word word", + expected: "word word word word word word word \n" + + " word word word", + }, + + // test breakable with 2 columns, colors and unicode character + { + input: colored("w😁rd wo😁d wor😁 😁ord ") + + strings.Repeat(colored("wor😁 ")+colored("😁ord "), 2), + expected: "" + + colored("w😁rd wo😁d wor😁 😁ord ") + + colored("wor😁 ") + colored("😁ord ") + colored("wor😁\n ") + + colored("😁ord "), + }, + + // test real-world content + { + input: ` +Usage of resticprofile: + resticprofile [resticprofile flags] [profile name.][restic command] [restic flags] + resticprofile [resticprofile flags] [profile name.][resticprofile command] [command specific flags] + +resticprofile flags: + -c, --config string configuration file (default "profiles") + --dry-run display the restic commands instead of running them + -f, --format string file format of the configuration (default is to use the file extension) + -h, --help display this help + --lock-wait duration wait up to duration to acquire a lock (syntax "1h5m30s") + -l, --log string logs to a target instead of the console + -n, --name string profile name (default "default") + --no-ansi disable ansi control characters (disable console colouring) + --no-lock skip profile lock file + --no-prio don't set any priority on load: used when started from a service that has already set the priority + -q, --quiet display only warnings and errors + --theme string console colouring theme (dark, light, none) (default "light") + --trace display even more debugging information + -v, --verbose display some debugging information + -w, --wait wait at the end until the user presses the enter key + + +resticprofile own commands: + help display help (run in verbose mode for detailed information) + version display version (run in verbose mode for detailed information) + self-update update to latest resticprofile (use -q/--quiet flag to update without confirmation) + profiles display profile names from the configuration file + show show all the details of the current profile + schedule schedule jobs from a profile (use --all flag to schedule all jobs of all profiles) + unschedule remove scheduled jobs of a profile (use --all flag to unschedule all profiles) + status display the status of scheduled jobs (use --all flag for all profiles) + generate generate resources (--random-key [size], --bash-completion & --zsh-completion) + +Documentation available at https://creativeprojects.github.io/resticprofile/ +`, + expected: ` +Usage of resticprofile: + resticprofile [resticprofile flags] + [profile name.][restic command] + [restic flags] + resticprofile [resticprofile flags] + [profile name.][resticprofile + command] [command specific flags] + +resticprofile flags: + -c, --config string + configuration + file (default + "profiles") + --dry-run display + the restic + commands + instead of + running them + -f, --format string file + format of the + configuration + (default is to + use the file + extension) + -h, --help display + this help + --lock-wait duration wait up to + duration to acquire a lock + (syntax "1h5m30s") + -l, --log string logs to a + target instead + of the console + -n, --name string profile + name (default + "default") + --no-ansi disable + ansi control + characters + (disable + console + colouring) + --no-lock skip + profile lock + file + --no-prio don't set + any priority + on load: used + when started + from a service + that has + already set + the priority + -q, --quiet display + only warnings + and errors + --theme string console + colouring + theme (dark, + light, none) + (default + "light") + --trace display + even more + debugging + information + -v, --verbose display + some debugging + information + -w, --wait wait at + the end until + the user + presses the + enter key + + +resticprofile own commands: + help display help (run in + verbose mode for + detailed information) + version display version (run + in verbose mode for + detailed information) + self-update update to latest + resticprofile (use + -q/--quiet flag to + update without + confirmation) + profiles display profile names + from the configuration + file + show show all the details + of the current profile + schedule schedule jobs from a + profile (use --all + flag to schedule all + jobs of all profiles) + unschedule remove scheduled jobs + of a profile (use + --all flag to + unschedule all + profiles) + status display the status of + scheduled jobs (use + --all flag for all + profiles) + generate generate resources + (--random-key [size], + --bash-completion & + --zsh-completion) + +Documentation available at +https://creativeprojects.github.io/resticprofile/ +`, + }, + } + + for i, test := range tests { + if test.scale == 0 { + test.scale = 1 + } + for chunkSize := 0; chunkSize <= test.chunks; chunkSize++ { + t.Run(fmt.Sprintf("%d-%d", i, chunkSize), func(t *testing.T) { + buffer := bytes.Buffer{} + writer := NewLineLengthWriter(&buffer, 40) + input := []byte(test.input) + + var ( + n int + err error + ) + switch chunkSize { + case 0: + n, err = writer.Write(input) + assert.Equal(t, len(input), n) + default: + for len(input) > 0 && err == nil { + length := min(test.scale*chunkSize, len(input)) + n, err = writer.Write(input[:length]) + assert.Equal(t, length, n) + input = input[length:] + } + } + + assert.Nil(t, err) + assert.Equal(t, test.expected, buffer.String()) + }) + } + } +} diff --git a/util/ansi/runes.go b/util/ansi/runes.go new file mode 100644 index 00000000..da578842 --- /dev/null +++ b/util/ansi/runes.go @@ -0,0 +1,26 @@ +package ansi + +// RunesLength returns the visible content length (and index at the length) +func RunesLength(src []rune, maxLength int) (length, index int) { + esc := rune(Escape[0]) + inEsc := false + index = -1 + for i, r := range src { + if r == esc { + inEsc = true + } else if inEsc { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + inEsc = false + } + } else { + if length == maxLength { + index = i + } + length++ + } + } + if index < 0 { + index = len(src) + } + return +} diff --git a/util/ansi/runes_test.go b/util/ansi/runes_test.go new file mode 100644 index 00000000..c00452b2 --- /dev/null +++ b/util/ansi/runes_test.go @@ -0,0 +1,40 @@ +package ansi + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunesLength(t *testing.T) { + tests := []struct { + input []rune + max, index, length int + }{ + {input: []rune{}, index: 0, length: 0, max: -1}, + {input: []rune(""), index: 0, length: 0, max: -1}, + {input: []rune(ClearLine + ""), index: len(ClearLine), length: 0, max: -1}, + {input: []rune(ClearLine + "" + ClearLine), index: 2 * len(ClearLine), length: 0, max: -1}, + {input: []rune(ClearLine + "◷" + ClearLine), index: 1 + 2*len(ClearLine), length: 1, max: -1}, + + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: len(ClearLine), length: 3, max: 0}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: 1 + len(ClearLine), length: 3, max: 1}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: 2 + len(ClearLine), length: 3, max: 2}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine), index: 3 + 2*len(ClearLine), length: 3, max: 3}, + + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 3 + 2*len(ClearLine), length: 6, max: 3}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 4 + 2*len(ClearLine), length: 6, max: 4}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 5 + 2*len(ClearLine), length: 6, max: 5}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 6 + 2*len(ClearLine), length: 6, max: 6}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 6 + 2*len(ClearLine), length: 6, max: 7}, + {input: []rune(ClearLine + "◷◷◷" + ClearLine + "123"), index: 6 + 2*len(ClearLine), length: 6, max: -1}, + } + for idx, test := range tests { + t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) { + length, index := RunesLength(test.input, test.max) + assert.Equal(t, test.length, length, "length") + assert.Equal(t, test.index, index, "index") + }) + } +} diff --git a/util/filewriter.go b/util/filewriter.go new file mode 100644 index 00000000..1e7e11a6 --- /dev/null +++ b/util/filewriter.go @@ -0,0 +1,229 @@ +package util + +import ( + "fmt" + "io" + "os" + "sync" + "time" + + "github.com/creativeprojects/resticprofile/platform" +) + +// AsyncFileWriterAppendFunc is called for every input byte when appending it to the output buffer (is similar to buf = append(buf, byte)) +type AsyncFileWriterAppendFunc func(dst []byte, c byte) []byte + +type asyncFileWriterOption func(writer *asyncFileWriter) + +// WithAsyncWriteInterval sets the interval at which writes happen at least when data is pending +func WithAsyncWriteInterval(duration time.Duration) asyncFileWriterOption { + return func(writer *asyncFileWriter) { writer.interval = duration } +} + +// WithAsyncFileKeepOpen toggles whether the file is kept open between writes. Defaults to true for all OS except Windows. +func WithAsyncFileKeepOpen(keepOpen bool) asyncFileWriterOption { + return func(writer *asyncFileWriter) { writer.keepOpen = keepOpen } +} + +// WithAsyncFileAppendFunc sets AsyncFileWriterAppendFunc. Default is to not use a custom appender. +func WithAsyncFileAppendFunc(appender AsyncFileWriterAppendFunc) asyncFileWriterOption { + return func(writer *asyncFileWriter) { writer.appender = appender } +} + +// WithAsyncFilePerm sets file perms to apply when creating the file +func WithAsyncFilePerm(perm os.FileMode) asyncFileWriterOption { + return func(writer *asyncFileWriter) { writer.perm = perm } +} + +// WithAsyncFileFlag sets file open flags +func WithAsyncFileFlag(flag int) asyncFileWriterOption { + return func(writer *asyncFileWriter) { writer.flag = flag } +} + +// WithAsyncFileTruncate enables that existing files are truncated +func WithAsyncFileTruncate() asyncFileWriterOption { + return func(writer *asyncFileWriter) { writer.flag |= os.O_TRUNC } +} + +const ( + asyncWriterDataChanSize = 64 + asyncWriterBlockSize = 4 * 1024 + asyncWriterMaxBlockSize = asyncWriterDataChanSize * asyncWriterBlockSize +) + +var ( + asyncWriterBufferPool = sync.Pool{ + New: func() any { return make([]byte, asyncWriterBlockSize) }, + } +) + +func asyncWriterReturnToPool(data []byte) { + if cap(data) == asyncWriterBlockSize { + asyncWriterBufferPool.Put(data[:0]) + } +} + +// NewAsyncFileWriter creates a file writer that accumulates Write requests and writes them at a fixed rate (every 250 ms by default) +func NewAsyncFileWriter(filename string, options ...asyncFileWriterOption) (io.WriteCloser, error) { + w := &asyncFileWriter{ + flush: make(chan chan error), + done: make(chan chan error), + data: make(chan []byte, asyncWriterDataChanSize), + perm: 0644, + flag: os.O_WRONLY | os.O_APPEND | os.O_CREATE, + interval: 250 * time.Millisecond, + keepOpen: !platform.IsWindows(), + } + for _, option := range options { + option(w) + } + + var ( + buffer []byte + lastError error + file *os.File + ) + + closeFile := func() { + if file != nil { + lastError = file.Close() + file = nil + } + } + + flush := func(alsoEmpty, whenTooBig bool) { + if len(buffer) == 0 && !alsoEmpty { + return + } + if len(buffer) < asyncWriterMaxBlockSize && whenTooBig { + return + } + if file == nil { + file, lastError = os.OpenFile(filename, w.flag, w.perm) + } + if file != nil { + var written int + written, lastError = file.Write(buffer) + if remaining := len(buffer) - written; remaining > 0 { + copy(buffer, buffer[written:]) + buffer = buffer[:remaining] + } else { + buffer = buffer[:0] + } + } + if w.keepOpen { + _ = file.Sync() + } else { + closeFile() + } + } + + // test if we can create the file + buffer = make([]byte, 0, asyncWriterBlockSize) + flush(true, false) + + // data appending + addToBuffer := func(data []byte) { + buffer = append(buffer, data...) // fast path + asyncWriterReturnToPool(data) + flush(false, true) + } + if w.appender != nil { + addToBuffer = func(data []byte) { + for _, c := range data { + buffer = w.appender(buffer, c) + } + asyncWriterReturnToPool(data) + flush(false, true) + } + } + + addPendingData := func(maxCount int) { + for ; maxCount > 0; maxCount-- { + select { + case data, ok := <-w.data: + if ok { + addToBuffer(data) + } else { + return // closed + } + default: + return // no-more-data + } + } + } + + // data transport + go func() { + ticker := time.NewTicker(w.interval) + defer ticker.Stop() + defer closeFile() + + for { + select { + case data := <-w.data: + addToBuffer(data) + case <-ticker.C: + flush(false, false) + case req := <-w.flush: + addPendingData(1024) + flush(false, false) + req <- lastError + case req := <-w.done: + defer func(response chan error) { + response <- lastError + }(req) + close(w.done) + close(w.flush) + close(w.data) + addPendingData(1024) + flush(false, false) + closeFile() + return + } + } + }() + + return w, lastError +} + +type asyncFileWriter struct { + done, flush chan chan error + data chan []byte + appender AsyncFileWriterAppendFunc + flag int + perm os.FileMode + keepOpen bool + interval time.Duration +} + +func (w *asyncFileWriter) Close() error { + req := make(chan error) + w.done <- req + return <-req +} + +func (w *asyncFileWriter) Flush() error { + req := make(chan error) + w.flush <- req + return <-req +} + +func (w *asyncFileWriter) Write(data []byte) (n int, err error) { + defer func() { + msg := recover() + if msg != nil { + err = fmt.Errorf("panic: %v", msg) + } + }() + var buffer []byte + if len(data) <= asyncWriterBlockSize { + buffer = asyncWriterBufferPool.Get().([]byte)[:len(data)] + } else { + buffer = make([]byte, len(data)) + } + + n = copy(buffer, data) + w.data <- buffer + return +} diff --git a/util/filewriter_test.go b/util/filewriter_test.go new file mode 100644 index 00000000..3c4cf690 --- /dev/null +++ b/util/filewriter_test.go @@ -0,0 +1,258 @@ +package util + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAsyncFileWriterBasicWrite(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + + n, err := w.Write([]byte("hello world")) + assert.NoError(t, err) + assert.Equal(t, 11, n) + + err = w.Close() + assert.NoError(t, err) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) +} + +func TestAsyncFileWriterMultipleWrites(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + + for _, chunk := range []string{"foo", "bar", "baz"} { + _, err = w.Write([]byte(chunk)) + require.NoError(t, err) + } + + err = w.Close() + assert.NoError(t, err) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "foobarbaz", string(content)) +} + +func TestAsyncFileWriterFlush(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename, + WithAsyncWriteInterval(10*time.Second), // long interval so flush drives the write + ) + require.NoError(t, err) + defer w.Close() + + _, err = w.Write([]byte("flushed")) + require.NoError(t, err) + + fw, ok := w.(*asyncFileWriter) + require.True(t, ok) + err = fw.Flush() + assert.NoError(t, err) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "flushed", string(content)) +} + +func TestAsyncFileWriterTruncate(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + // First write + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + _, err = w.Write([]byte("original")) + require.NoError(t, err) + require.NoError(t, w.Close()) + + // Second write with truncate + w, err = NewAsyncFileWriter(filename, WithAsyncFileTruncate()) + require.NoError(t, err) + _, err = w.Write([]byte("new")) + require.NoError(t, err) + require.NoError(t, w.Close()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "new", string(content)) +} + +func TestAsyncFileWriterAppendMode(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + // First write + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + _, err = w.Write([]byte("first")) + require.NoError(t, err) + require.NoError(t, w.Close()) + + // Second write appends by default + w, err = NewAsyncFileWriter(filename) + require.NoError(t, err) + _, err = w.Write([]byte("second")) + require.NoError(t, err) + require.NoError(t, w.Close()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "firstsecond", string(content)) +} + +func TestAsyncFileWriterFilePerm(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename, WithAsyncFilePerm(0600)) + require.NoError(t, err) + _, err = w.Write([]byte("data")) + require.NoError(t, err) + require.NoError(t, w.Close()) + + info, err := os.Stat(filename) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0600), info.Mode().Perm()) +} + +func TestAsyncFileWriterCustomAppendFunc(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + // appender that uppercases every byte + upperAppender := func(dst []byte, c byte) []byte { + if c >= 'a' && c <= 'z' { + c -= 32 + } + return append(dst, c) + } + + w, err := NewAsyncFileWriter(filename, WithAsyncFileAppendFunc(upperAppender)) + require.NoError(t, err) + _, err = w.Write([]byte("hello")) + require.NoError(t, err) + require.NoError(t, w.Close()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "HELLO", string(content)) +} + +func TestAsyncFileWriterKeepOpenFalse(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename, WithAsyncFileKeepOpen(false)) + require.NoError(t, err) + _, err = w.Write([]byte("keep-closed")) + require.NoError(t, err) + require.NoError(t, w.Close()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "keep-closed", string(content)) +} + +func TestAsyncFileWriterInvalidPath(t *testing.T) { + _, err := NewAsyncFileWriter("/nonexistent/path/test.log") + assert.Error(t, err) +} + +func TestAsyncFileWriterLargeWrite(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + // Write more than asyncWriterBlockSize (4KB) + large := strings.Repeat("x", asyncWriterBlockSize*3) + + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + n, err := w.Write([]byte(large)) + assert.NoError(t, err) + assert.Equal(t, len(large), n) + require.NoError(t, w.Close()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, large, string(content)) +} + +func TestAsyncFileWriterWriteInterval(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename, WithAsyncWriteInterval(10*time.Millisecond)) + require.NoError(t, err) + defer w.Close() + + _, err = w.Write([]byte("interval")) + require.NoError(t, err) + + // Wait enough time for the ticker to fire + time.Sleep(50 * time.Millisecond) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "interval", string(content)) +} + +func TestAsyncFileWriterWriteEmptyData(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + + n, err := w.Write([]byte{}) + assert.NoError(t, err) + assert.Equal(t, 0, n) + + require.NoError(t, w.Close()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Empty(t, content) +} + +func TestAsyncFileWriterBufferReuse(t *testing.T) { + // Verify pooled-size writes don't corrupt data due to buffer reuse + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + + var expected bytes.Buffer + for i := range 10 { + chunk := []byte(strings.Repeat(string(rune('a'+i)), asyncWriterBlockSize)) + _, err = w.Write(chunk) + require.NoError(t, err) + expected.Write(chunk) + } + + require.NoError(t, w.Close()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, expected.String(), string(content)) +} diff --git a/util/reader.go b/util/reader.go new file mode 100644 index 00000000..8a969c62 --- /dev/null +++ b/util/reader.go @@ -0,0 +1,94 @@ +package util + +import ( + "io" + "sync" +) + +// ReadFunc is the callback in NewFilterReader & NewFilterReadCloser for io.Reader.Read. +type ReadFunc func(bytes []byte) (n int, err error) + +type filterReader struct { + read ReadFunc +} + +func (f *filterReader) Read(bytes []byte) (n int, err error) { + if f.read == nil { + err = io.EOF + return + } + return f.read(bytes) +} + +// NewFilterReader creates a new io.Reader redirects all read calls to ReadFunc +func NewFilterReader(read ReadFunc) io.Reader { + return &filterReader{read} +} + +// CloseFunc is the callback in NewFilterReadCloser for io.Closer.Close. It is guaranteed to be called once. +type CloseFunc func() (err error) + +type filterReadCloser struct { + filterReader + close CloseFunc +} + +func (c filterReadCloser) Close() error { + defer func() { + c.close = nil + c.read = nil + }() + if c.close == nil { + return nil + } + return c.close() +} + +// NewFilterReadCloser creates a new io.ReadCloser redirects all calls to ReadFunc and CloseFunc +func NewFilterReadCloser(read ReadFunc, closer CloseFunc) io.ReadCloser { + return &filterReadCloser{*NewFilterReader(read).(*filterReader), closer} +} + +// NewSyncReader creates a new reader that is safe for concurrent use +func NewSyncReader[R io.Reader](reader R) SyncReader[R] { + mutex := new(sync.Mutex) + return NewSyncReaderMutex(reader, mutex, mutex) +} + +// NewSyncReaderMutex creates a new reader that is safe for concurrent use and synced with the specified sync.Mutex +func NewSyncReaderMutex[R io.Reader](reader R, mutex, closeMutex *sync.Mutex) SyncReader[R] { + return &syncReader[R]{reader: reader, mutex: mutex, closeMutex: closeMutex} +} + +// SyncReader implements an io.ReadCloser that is safe for concurrent use +type SyncReader[R io.Reader] interface { + io.ReadCloser + // Locked provides sync.Mutex locked access to the underlying reader + Locked(fn func(R) error) error +} + +type syncReader[R io.Reader] struct { + reader io.Reader + mutex, closeMutex *sync.Mutex +} + +func (w *syncReader[R]) Locked(fn func(reader R) error) error { + w.mutex.Lock() + defer w.mutex.Unlock() + return fn(w.reader.(R)) +} + +func (w *syncReader[R]) Read(p []byte) (n int, err error) { + w.mutex.Lock() + defer w.mutex.Unlock() + return w.reader.Read(p) +} + +func (w *syncReader[R]) Close() (err error) { + w.closeMutex.Lock() + defer w.closeMutex.Unlock() + if f, ok := w.reader.(io.Closer); ok { + err = f.Close() + } + return +} diff --git a/util/reader_test.go b/util/reader_test.go new file mode 100644 index 00000000..39f98ffc --- /dev/null +++ b/util/reader_test.go @@ -0,0 +1,293 @@ +package util + +import ( + "errors" + "io" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// readCloser wraps a reader and tracks whether Close was called. +type trackingReadCloser struct { + io.Reader + closed bool + err error +} + +func (t *trackingReadCloser) Close() error { + t.closed = true + return t.err +} + +// filterReader & filterReadCloser tests + +func TestFilterReaderRead(t *testing.T) { + inner := strings.NewReader("hello") + r := NewFilterReader(inner.Read) + + buf := make([]byte, 5) + n, err := r.Read(buf) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "hello", string(buf)) +} + +func TestFilterReaderReadEOF(t *testing.T) { + inner := strings.NewReader("") + r := NewFilterReader(inner.Read) + + buf := make([]byte, 4) + _, err := r.Read(buf) + assert.Equal(t, io.EOF, err) +} + +func TestFilterReaderNilReadFunc(t *testing.T) { + r := NewFilterReader(nil) + + buf := make([]byte, 4) + n, err := r.Read(buf) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, n) +} + +func TestFilterReaderReadAll(t *testing.T) { + inner := strings.NewReader("read all") + r := NewFilterReader(inner.Read) + + data, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, "read all", string(data)) +} + +func TestFilterReaderPropagatesError(t *testing.T) { + sentinel := errors.New("read error") + r := NewFilterReader(func([]byte) (int, error) { return 0, sentinel }) + + buf := make([]byte, 4) + _, err := r.Read(buf) + assert.Equal(t, sentinel, err) +} + +func TestFilterReadCloserRead(t *testing.T) { + inner := strings.NewReader("world") + rc := NewFilterReadCloser(inner.Read, func() error { return nil }) + + buf := make([]byte, 5) + n, err := rc.Read(buf) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "world", string(buf)) +} + +func TestFilterReadCloserClose(t *testing.T) { + closed := false + inner := strings.NewReader("") + rc := NewFilterReadCloser(inner.Read, func() error { + closed = true + return nil + }) + + err := rc.Close() + assert.NoError(t, err) + assert.True(t, closed) +} + +func TestFilterReadCloserCloseCalledTwice(t *testing.T) { + // Close() uses a value receiver so the nil-guard mutation does not persist; + // calling Close twice invokes the close func both times. + calls := 0 + inner := strings.NewReader("") + rc := NewFilterReadCloser(inner.Read, func() error { + calls++ + return nil + }) + + _ = rc.Close() + _ = rc.Close() + assert.Equal(t, 2, calls) +} + +func TestFilterReadCloserCloseError(t *testing.T) { + sentinel := errors.New("close error") + inner := strings.NewReader("") + rc := NewFilterReadCloser(inner.Read, func() error { return sentinel }) + + err := rc.Close() + assert.Equal(t, sentinel, err) +} + +func TestFilterReadCloserNilReadFunc(t *testing.T) { + rc := NewFilterReadCloser(nil, func() error { return nil }) + + buf := make([]byte, 4) + n, err := rc.Read(buf) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, n) +} + +func TestFilterReadCloserNilCloseFunc(t *testing.T) { + inner := strings.NewReader("") + rc := NewFilterReadCloser(inner.Read, nil) + + err := rc.Close() + assert.NoError(t, err) +} + +func TestFilterReadCloserReadAfterClose(t *testing.T) { + // Close() uses a value receiver so read is NOT cleared on the original; + // reads after close continue to work via the underlying reader. + inner := strings.NewReader("data") + rc := NewFilterReadCloser(inner.Read, func() error { return nil }) + + require.NoError(t, rc.Close()) + + buf := make([]byte, 4) + n, err := rc.Read(buf) + assert.NoError(t, err) + assert.Equal(t, "data", string(buf[:n])) +} + +// syncReader tests + +func TestNewSyncReaderRead(t *testing.T) { + inner := strings.NewReader("hello world") + sr := NewSyncReader[io.Reader](inner) + + buf := make([]byte, 11) + n, err := sr.Read(buf) + assert.NoError(t, err) + assert.Equal(t, 11, n) + assert.Equal(t, "hello world", string(buf)) +} + +func TestNewSyncReaderReadEOF(t *testing.T) { + inner := strings.NewReader("") + sr := NewSyncReader[io.Reader](inner) + + buf := make([]byte, 8) + _, err := sr.Read(buf) + assert.Equal(t, io.EOF, err) +} + +func TestNewSyncReaderCloseWithCloser(t *testing.T) { + rc := &trackingReadCloser{Reader: strings.NewReader("data")} + sr := NewSyncReader(rc) + + err := sr.Close() + assert.NoError(t, err) + assert.True(t, rc.closed) +} + +func TestNewSyncReaderCloseNonCloser(t *testing.T) { + // strings.Reader does not implement io.Closer; Close should be a no-op + inner := strings.NewReader("data") + sr := NewSyncReader[io.Reader](inner) + + err := sr.Close() + assert.NoError(t, err) +} + +func TestNewSyncReaderCloseError(t *testing.T) { + closeErr := errors.New("close failed") + rc := &trackingReadCloser{Reader: strings.NewReader("data"), err: closeErr} + sr := NewSyncReader(rc) + + err := sr.Close() + assert.Equal(t, closeErr, err) +} + +func TestNewSyncReaderLocked(t *testing.T) { + inner := strings.NewReader("locked") + sr := NewSyncReader(inner) + + var seen string + err := sr.Locked(func(r *strings.Reader) error { + buf := make([]byte, 6) + n, err := r.Read(buf) + seen = string(buf[:n]) + return err + }) + assert.NoError(t, err) + assert.Equal(t, "locked", seen) +} + +func TestNewSyncReaderLockedPropagatesError(t *testing.T) { + inner := strings.NewReader("data") + sr := NewSyncReader(inner) + + sentinel := errors.New("locked error") + err := sr.Locked(func(_ *strings.Reader) error { + return sentinel + }) + assert.Equal(t, sentinel, err) +} + +func TestNewSyncReaderMutex(t *testing.T) { + inner := strings.NewReader("shared mutex") + mutex := new(sync.Mutex) + sr := NewSyncReaderMutex[io.Reader](inner, mutex, mutex) + + buf := make([]byte, 12) + n, err := sr.Read(buf) + assert.NoError(t, err) + assert.Equal(t, 12, n) + assert.Equal(t, "shared mutex", string(buf)) +} + +func TestNewSyncReaderMutexSeparateCloseMutex(t *testing.T) { + rc := &trackingReadCloser{Reader: strings.NewReader("separate")} + readMutex := new(sync.Mutex) + closeMutex := new(sync.Mutex) + sr := NewSyncReaderMutex(rc, readMutex, closeMutex) + + err := sr.Close() + assert.NoError(t, err) + assert.True(t, rc.closed) +} + +func TestNewSyncReaderConcurrentReads(t *testing.T) { + // Writes a large enough buffer so concurrent goroutines don't exhaust it + // immediately; we just want to confirm no races / panics occur. + data := strings.Repeat("x", 1024) + inner := strings.NewReader(data) + sr := NewSyncReader[io.Reader](inner) + + var wg sync.WaitGroup + for range 10 { + wg.Go(func() { + buf := make([]byte, 32) + _, _ = sr.Read(buf) + }) + } + wg.Wait() +} + +func TestNewSyncReaderConcurrentReadAndClose(t *testing.T) { + rc := &trackingReadCloser{Reader: strings.NewReader(strings.Repeat("y", 512))} + sr := NewSyncReader(rc) + + var wg sync.WaitGroup + for range 5 { + wg.Go(func() { + buf := make([]byte, 16) + _, _ = sr.Read(buf) + }) + } + wg.Go(func() { + _ = sr.Close() + }) + wg.Wait() +} + +func TestNewSyncReaderReadAll(t *testing.T) { + inner := strings.NewReader("read all content") + sr := NewSyncReader[io.Reader](inner) + + content, err := io.ReadAll(sr) + require.NoError(t, err) + assert.Equal(t, "read all content", string(content)) +} diff --git a/util/writer.go b/util/writer.go new file mode 100644 index 00000000..58a02283 --- /dev/null +++ b/util/writer.go @@ -0,0 +1,73 @@ +package util + +import ( + "io" + "sync" +) + +// Flusher allows a Writer to declare it may buffer content that can be flushed +type Flusher interface { + // Flush writes any pending bytes to output + Flush() error +} + +// FlushWriter attempts to flush a writer if it implements Flusher +func FlushWriter(writer io.Writer) (flushable bool, err error) { + var f Flusher + if f, flushable = writer.(Flusher); flushable { + err = f.Flush() + } + return +} + +// NewSyncWriter creates a new writer that is safe for concurrent use +func NewSyncWriter[W io.Writer](writer W) SyncWriter[W] { + return NewSyncWriterMutex(writer, new(sync.Mutex)) +} + +// NewSyncWriterMutex creates a new writer that is safe for concurrent use and synced with the specified sync.Mutex +func NewSyncWriterMutex[W io.Writer](writer W, mutex *sync.Mutex) SyncWriter[W] { + return &syncWriter[W]{writer: writer, mutex: mutex} +} + +// SyncWriter implements an io.Writer that is safe for concurrent use +type SyncWriter[W io.Writer] interface { + io.Writer + // Locked provides sync.Mutex locked access to the underlying writer + Locked(fn func(W) error) error +} + +type syncWriter[W io.Writer] struct { + writer io.Writer + mutex *sync.Mutex +} + +func (w *syncWriter[W]) Locked(fn func(writer W) error) error { + w.mutex.Lock() + defer w.mutex.Unlock() + return fn(w.writer.(W)) +} + +func (w *syncWriter[W]) Write(p []byte) (n int, err error) { + w.mutex.Lock() + defer w.mutex.Unlock() + return w.writer.Write(p) +} + +func (w *syncWriter[W]) Flush() (err error) { + w.mutex.Lock() + defer w.mutex.Unlock() + if f, ok := w.writer.(Flusher); ok { + err = f.Flush() + } + return +} + +func (w *syncWriter[W]) Close() (err error) { + w.mutex.Lock() + defer w.mutex.Unlock() + if f, ok := w.writer.(io.Closer); ok { + err = f.Close() + } + return +} diff --git a/util/writer_test.go b/util/writer_test.go new file mode 100644 index 00000000..790c7252 --- /dev/null +++ b/util/writer_test.go @@ -0,0 +1,207 @@ +package util + +import ( + "bytes" + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// trackingWriteCloser is a bytes.Buffer that also tracks Close and Flush calls. +type trackingWriteCloser struct { + bytes.Buffer + closed bool + flushed bool + closeErr error + flushErr error +} + +func (t *trackingWriteCloser) Close() error { + t.closed = true + return t.closeErr +} + +func (t *trackingWriteCloser) Flush() error { + t.flushed = true + return t.flushErr +} + +// --- FlushWriter tests --- + +func TestFlushWriterFlushable(t *testing.T) { + tw := &trackingWriteCloser{} + flushable, err := FlushWriter(tw) + assert.True(t, flushable) + assert.NoError(t, err) + assert.True(t, tw.flushed) +} + +func TestFlushWriterNotFlushable(t *testing.T) { + var buf bytes.Buffer + flushable, err := FlushWriter(&buf) + assert.False(t, flushable) + assert.NoError(t, err) +} + +func TestFlushWriterFlushError(t *testing.T) { + sentinel := errors.New("flush error") + tw := &trackingWriteCloser{flushErr: sentinel} + flushable, err := FlushWriter(tw) + assert.True(t, flushable) + assert.Equal(t, sentinel, err) +} + +// --- syncWriter tests --- + +func TestSyncWriterWrite(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "hello", buf.String()) +} + +func TestSyncWriterMultipleWrites(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + for _, s := range []string{"foo", "bar", "baz"} { + _, err := w.Write([]byte(s)) + require.NoError(t, err) + } + assert.Equal(t, "foobarbaz", buf.String()) +} + +func TestSyncWriterLocked(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + err := w.Locked(func(inner *bytes.Buffer) error { + _, err := inner.WriteString("locked") + return err + }) + assert.NoError(t, err) + assert.Equal(t, "locked", buf.String()) +} + +func TestSyncWriterLockedPropagatesError(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + sentinel := errors.New("locked error") + err := w.Locked(func(_ *bytes.Buffer) error { return sentinel }) + assert.Equal(t, sentinel, err) +} + +func TestSyncWriterFlushWithFlusher(t *testing.T) { + tw := &trackingWriteCloser{} + w := NewSyncWriter(tw) + + // syncWriter exposes Flush via a type assertion at call site + f, ok := w.(Flusher) + require.True(t, ok) + err := f.Flush() + assert.NoError(t, err) + assert.True(t, tw.flushed) +} + +func TestSyncWriterFlushError(t *testing.T) { + sentinel := errors.New("flush error") + tw := &trackingWriteCloser{flushErr: sentinel} + w := NewSyncWriter(tw) + + err := w.(Flusher).Flush() + assert.Equal(t, sentinel, err) +} + +func TestSyncWriterFlushNonFlusher(t *testing.T) { + // bytes.Buffer is not a Flusher; Flush should be a no-op + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + err := w.(Flusher).Flush() + assert.NoError(t, err) +} + +func TestSyncWriterCloseWithCloser(t *testing.T) { + tw := &trackingWriteCloser{} + w := NewSyncWriter(tw) + + type closer interface{ Close() error } + err := w.(closer).Close() + assert.NoError(t, err) + assert.True(t, tw.closed) +} + +func TestSyncWriterCloseError(t *testing.T) { + sentinel := errors.New("close error") + tw := &trackingWriteCloser{closeErr: sentinel} + w := NewSyncWriter(tw) + + type closer interface{ Close() error } + err := w.(closer).Close() + assert.Equal(t, sentinel, err) +} + +func TestSyncWriterCloseNonCloser(t *testing.T) { + // bytes.Buffer is not an io.Closer; Close should be a no-op + var buf bytes.Buffer + w := NewSyncWriter(&buf) + + type closer interface{ Close() error } + err := w.(closer).Close() + assert.NoError(t, err) +} + +func TestSyncWriterMutex(t *testing.T) { + var buf bytes.Buffer + mutex := new(sync.Mutex) + w := NewSyncWriterMutex[*bytes.Buffer](&buf, mutex) + + n, err := w.Write([]byte("shared")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + assert.Equal(t, "shared", buf.String()) +} + +func TestSyncWriterConcurrentWrites(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter[*bytes.Buffer](&buf) + + var wg sync.WaitGroup + for range 50 { + wg.Go(func() { + _, _ = w.Write([]byte("x")) + }) + } + wg.Wait() + assert.Equal(t, 50, buf.Len()) +} + +func TestSyncWriterConcurrentLockedAndWrite(t *testing.T) { + var buf bytes.Buffer + w := NewSyncWriter[*bytes.Buffer](&buf) + + var wg sync.WaitGroup + for range 10 { + wg.Add(2) + go func() { + defer wg.Done() + _, _ = w.Write([]byte("w")) + }() + go func() { + defer wg.Done() + _ = w.Locked(func(b *bytes.Buffer) error { + _, err := b.WriteString("l") + return err + }) + }() + } + wg.Wait() + assert.Equal(t, 20, buf.Len()) +} diff --git a/wrapper_test.go b/wrapper_test.go index fd4a8793..96cc8ddd 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "errors" "fmt" "io" @@ -423,8 +422,8 @@ func Example_runProfile() { } func TestRunRedirectOutputOfEchoProfile(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "name") ctx := &Context{ binary: "echo", @@ -434,12 +433,12 @@ func TestRunRedirectOutputOfEchoProfile(t *testing.T) { wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "test", strings.TrimSpace(buffer.String())) + assert.Equal(t, "test", strings.TrimSpace(term.ReadRecording())) } func TestDryRun(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "name") wrapper := newResticWrapper(&Context{ flags: commandLineFlags{dryRun: true}, @@ -449,12 +448,12 @@ func TestDryRun(t *testing.T) { }) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "", buffer.String()) + assert.Equal(t, "", term.ReadRecording()) } func TestEnvProfileName(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "TestEnvProfileName") profile.RunBefore = []string{"echo profile name = $PROFILE_NAME"} @@ -466,12 +465,12 @@ func TestEnvProfileName(t *testing.T) { wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "profile name = TestEnvProfileName\ntest\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) + assert.Equal(t, "profile name = TestEnvProfileName\ntest\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) } func TestEnvProfileCommand(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "name") profile.RunBefore = []string{"echo profile command = $PROFILE_COMMAND"} @@ -483,12 +482,12 @@ func TestEnvProfileCommand(t *testing.T) { wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "profile command = test-command\ntest-command\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) + assert.Equal(t, "profile command = test-command\ntest-command\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) } func TestEnvError(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo error: $ERROR_MESSAGE"} @@ -500,12 +499,12 @@ func TestEnvError(t *testing.T) { wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "error: 1 on profile 'name': exit status 1\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) + assert.Equal(t, "error: 1 on profile 'name': exit status 1\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) } func TestEnvErrorCommandLine(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo cmd: $ERROR_COMMANDLINE"} @@ -517,12 +516,12 @@ func TestEnvErrorCommandLine(t *testing.T) { wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "cmd: \"exit\" \"1\"\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) + assert.Equal(t, "cmd: \"exit\" \"1\"\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) } func TestEnvErrorExitCode(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo exit-code: $ERROR_EXIT_CODE"} @@ -534,12 +533,12 @@ func TestEnvErrorExitCode(t *testing.T) { wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "exit-code: 5\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) + assert.Equal(t, "exit-code: 5\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) } func TestEnvStderr(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo stderr: $ERROR_STDERR"} @@ -552,7 +551,7 @@ func TestEnvStderr(t *testing.T) { wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "stderr: error_message", strings.TrimSpace(strings.ReplaceAll(buffer.String(), "\r\n", "\n"))) + assert.Equal(t, "stderr: error_message", strings.TrimSpace(strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n"))) } func TestRunProfileWithSetPIDCallback(t *testing.T) { @@ -1079,8 +1078,8 @@ func TestRunShellCommands(t *testing.T) { } func TestRunStreamErrorHandler(t *testing.T) { - buffer := &bytes.Buffer{} - term.SetOutput(buffer) + term.StartRecording(term.RecordOutput) + defer term.StopRecording() errorCommand := `echo "detected error in $PROFILE_COMMAND"` @@ -1097,7 +1096,7 @@ func TestRunStreamErrorHandler(t *testing.T) { err := wrapper.runProfile() require.NoError(t, err) - assert.Contains(t, buffer.String(), "detected error in backup") + assert.Contains(t, term.ReadRecording(), "detected error in backup") } func TestRunStreamErrorHandlerDoesNotBreakCommand(t *testing.T) { From a21a1d68cd32d6a7fedf84303f9ba6767ae54b28 Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 15:14:26 +0100 Subject: [PATCH 02/15] Refactor terminal handling in wrapper and tests - Removed the Size method from the Terminal struct as it was no longer needed. - Updated the wrapper to directly use the terminal's Stdout and Stderr methods instead of the term package. - Replaced term output handling in tests with direct terminal instances to improve clarity and maintainability. - Ensured that FlushAllOutput is called on the terminal instance in relevant places. - Adjusted tests to instantiate the terminal correctly, ensuring consistent output handling across test cases. --- commands_display.go | 4 +- flags.go | 2 +- integration_test.go | 26 ++- term/default.go | 17 ++ term/term.go | 227 ------------------ term/terminal.go | 11 - wrapper.go | 21 +- wrapper_streamsource.go | 3 +- wrapper_test.go | 504 ++++++++++++++++++++++------------------ 9 files changed, 326 insertions(+), 489 deletions(-) diff --git a/commands_display.go b/commands_display.go index f11ccbc9..d6b87b0e 100644 --- a/commands_display.go +++ b/commands_display.go @@ -22,9 +22,9 @@ import ( ) func displayWriter(terminal *term.Terminal) (out func(args ...any) io.Writer, closer func()) { - var output io.Writer = terminal + var output io.Writer = terminal.Stdout() if terminal.StdoutIsTerminal() { - if width, _ := terminal.Size(); width > 10 { + if width, _ := term.Size(); width > 10 { output = ansi.NewLineLengthWriter(terminal, width) } } diff --git a/flags.go b/flags.go index 87dfd51c..1b51c7fd 100644 --- a/flags.go +++ b/flags.go @@ -139,7 +139,7 @@ func loadFlags(args []string) (*pflag.FlagSet, commandLineFlags, error) { flagset.SetInterspersed(false) // Store usage help for help command - width, _ := term.OsStdoutTerminalSize() + width, _ := term.Size() flags.usagesHelp = flagset.FlagUsagesWrapped(width) if err := flagset.Parse(args); err != nil { diff --git a/integration_test.go b/integration_test.go index ad567d4d..bbf45e59 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "path/filepath" "strings" "testing" @@ -119,22 +120,23 @@ func TestFromConfigFileToCommandLine(t *testing.T) { require.NoError(t, err) require.NotNil(t, profile) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) + ctx := &Context{ request: Request{ profile: fixture.profileName, arguments: fixture.cmdlineArgs, }, - binary: echoBinary, - profile: profile, - command: fixture.commandName, + binary: echoBinary, + profile: profile, + command: fixture.commandName, + terminal: terminal, } wrapper := newResticWrapper(ctx) - // setting the output via the package global setter could lead to some issues - // when some tests are running in parallel. I should fix that at some point :-/ - term.StartRecording(term.RecordOutput) err = wrapper.runCommand(fixture.commandName) - stdout := term.StopRecording() + stdout := buffer.String() require.NoError(t, err) @@ -159,6 +161,9 @@ func TestFromConfigFileToCommandLine(t *testing.T) { require.NoError(t, err) require.NotNil(t, profile) + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) + ctx := &Context{ request: Request{ profile: fixture.profileName, @@ -168,13 +173,12 @@ func TestFromConfigFileToCommandLine(t *testing.T) { profile: profile, command: fixture.commandName, legacyArgs: true, + terminal: terminal, } wrapper := newResticWrapper(ctx) - // setting the output via the package global setter could lead to some issues - // when some tests are running in parallel. I should fix that at some point :-/ - term.StartRecording(term.RecordOutput) + err = wrapper.runCommand(fixture.commandName) - content := term.StopRecording() + content := buffer.String() require.NoError(t, err) diff --git a/term/default.go b/term/default.go index 340b0e07..4e0bcb08 100644 --- a/term/default.go +++ b/term/default.go @@ -1,5 +1,11 @@ package term +import ( + "os" + + "golang.org/x/term" +) + var defaultTerminal *Terminal // Get returns the default terminal. It will be initialized on first use with NewTerminal() if not set before. @@ -15,3 +21,14 @@ func Set(t *Terminal) *Terminal { defaultTerminal = t return defaultTerminal } + +// Size returns the width and height of the terminal session +func Size() (width, height int) { + fd := fdToInt(os.Stdout.Fd()) + var err error + width, height, err = term.GetSize(fd) + if err != nil { + width, height = 0, 0 + } + return +} diff --git a/term/term.go b/term/term.go index 04df3e84..eaf18c16 100644 --- a/term/term.go +++ b/term/term.go @@ -2,17 +2,13 @@ package term import ( "bufio" - "bytes" "fmt" "io" "os" "strings" "sync/atomic" - "time" "github.com/creativeprojects/resticprofile/util" - "github.com/creativeprojects/resticprofile/util/ansi" - "github.com/mattn/go-colorable" "golang.org/x/term" ) @@ -32,7 +28,6 @@ const ( func init() { enableColors.Store(true) - go handleStatus() // must be last { setOutput(os.Stdout) @@ -40,92 +35,6 @@ func init() { } } -func handleStatus() { - ticker := time.NewTicker(time.Second / StatusFPS) - defer ticker.Stop() - - var waiting []chan bool - respondWaiting := func(result bool) { - for _, request := range waiting { - request <- result - close(request) - } - waiting = nil - } - defer respondWaiting(false) - - var newStatus, status []string - buffer := &bytes.Buffer{} - for { - select { - case lines := <-statusChannel: - newStatus = lines - - case request := <-statusWaitChannel: - waiting = append(waiting, request) - - case <-ticker.C: - if status != nil && OutputIsTerminal() { - width, height := OsStdoutTerminalSize() - noAnsi := !IsColorableOutput() - if height < 1 { - continue - } else if noAnsi { - newStatus = newStatus[1:] // strip first empty line - height = 1 - } - if width >= 60 { - width -= 2 - } else if width >= 80 { - width -= 4 // right margin - } - - last := truncate(status, height) - printable := truncate(newStatus, height) - removedLines := len(last) - len(printable) - if removedLines > 0 { - filler := make([]string, removedLines, removedLines+len(printable)) - printable = append(filler, printable...) - } - - if len(printable) > 0 { - buffer.Reset() - for index, line := range printable { - runes := []rune(strings.ReplaceAll(line, "\n", " ")) - _, maxIndex := ansi.RunesLength(runes, width) - runes = truncate(runes, maxIndex) - - if noAnsi { - if remaining := width - len(runes); remaining > 0 { - for remaining > 0 { - runes = append(runes, ' ') - remaining-- - } - } - _, _ = fmt.Fprintf(buffer, "\r%s\r", string(runes)) - } else { - eol := "\n" - if index+1 == len(printable) { - eol = "\r" - } - _, _ = fmt.Fprintf(buffer, "\r%s%s%s%s", ansi.ClearLine, string(runes), ansi.Reset, eol) - } - } - - if !noAnsi { - buffer.WriteString(ansi.CursorUpLeftN(len(printable) - 1)) - } - - _, _ = buffer.WriteTo(getColorableOutput()) - buffer.Reset() - } - } - status = newStatus - respondWaiting(true) - } - } -} - func truncate[E any](src []E, maxLength int) []E { if len(src) > maxLength { return src[:maxLength] @@ -133,26 +42,6 @@ func truncate[E any](src []E, maxLength int) []E { return src } -// SetStatus sets a status line(s) that is printed when the output is an interactive terminal -// -// Deprecated: use term.Terminal instead -func SetStatus(line []string) { - // Clone lines and add empty line on top (= cursor position after printing status) - if line != nil { - line = append([]string{""}, line...) - } - statusChannel <- line -} - -// WaitForStatus blocks until the previously provided status was applied -// -// Deprecated: use term.Terminal instead -func WaitForStatus() bool { - request := make(chan bool, 1) - statusWaitChannel <- request - return <-request -} - // ReadPassword reads a password without echoing it to the terminal. // // Deprecated: use term.Terminal instead @@ -186,26 +75,6 @@ func OsStdoutIsTerminal() bool { return isTerminal(os.Stdout) } -// OsStdoutTerminalSize returns the width and height of the terminal session -// -// Deprecated: use term.Terminal instead -func OsStdoutTerminalSize() (width, height int) { - fd := fdToInt(os.Stdout.Fd()) - var err error - width, height, err = term.GetSize(fd) - if err != nil { - width, height = 0, 0 - } - return -} - -// OutputIsTerminal returns true if GetOutput sends to an interactive terminal -// -// Deprecated: use term.Terminal instead -func OutputIsTerminal() bool { - return GetOutput() == os.Stdout && OsStdoutIsTerminal() -} - // SetOutput changes the default output for the Print* functions // // Deprecated: use term.Terminal instead @@ -223,7 +92,6 @@ func setOutput(w io.Writer) { } termOutput.Store(&w) colorOutput.Store(nil) - SetStatus(nil) } // GetOutput returns the default output of the Print* functions @@ -236,29 +104,6 @@ func GetOutput() (out io.Writer) { return } -// getColorableOutput returns an output supporting ANSI color if output is a terminal -func getColorableOutput() (out io.Writer) { - if v := colorOutput.Load(); v != nil { - out = *v - } - if out == nil { - if IsColorableOutput() { - out = colorable.NewColorable(os.Stdout) - } else { - out = colorable.NewNonColorable(outputWriter()) - } - colorOutput.Store(&out) - } - return out -} - -// IsColorableOutput tells whether GetColorableOutput supports ANSI color (and control characters) or discards ANSI -// -// Deprecated: use term.Terminal instead -func IsColorableOutput() bool { - return enableColors.Load() && OutputIsTerminal() -} - // SetErrorOutput changes the error output for the Print* functions // // Deprecated: use term.Terminal instead @@ -294,75 +139,3 @@ func SetAllOutput(w io.Writer) { SetOutput(w) setErrorOutput(GetOutput()) } - -// FlushAllOutput flushes all buffered output (if supported by the underlying Writer). -// -// Deprecated: use term.Terminal instead -func FlushAllOutput() { - for _, writer := range []io.Writer{GetOutput(), GetErrorOutput()} { - _, _ = util.FlushWriter(writer) - } -} - -var recording = &outputRecording{} - -// Deprecated: use term.Terminal instead -func StartRecording(mode RecordMode) { - recording.lock.Lock() - defer recording.lock.Unlock() - if recording.buffer != nil { - return - } - - recording.buffer = new(bytes.Buffer) - recording.writer = util.NewSyncWriterMutex(recording.buffer, &recording.lock) - - if mode != RecordError { - recording.output = GetOutput() - setOutput(recording.writer) - } - if mode != RecordOutput { - recording.error = GetErrorOutput() - setErrorOutput(recording.writer) - } -} - -// Deprecated: use term.Terminal instead -func ReadRecording() (content string) { - recording.lock.Lock() - defer recording.lock.Unlock() - if recording.buffer != nil { - content = recording.buffer.String() - recording.buffer.Reset() - } - return -} - -// Deprecated: use term.Terminal instead -func StopRecording() (content string) { - recording.lock.Lock() - defer recording.lock.Unlock() - if recording.buffer != nil { - if recording.output != nil && recording.writer == GetOutput() { - setOutput(recording.output) - recording.output = nil - } - - if recording.error != nil && recording.writer == GetErrorOutput() { - setErrorOutput(recording.error) - recording.error = nil - } - - content = recording.buffer.String() - recording.writer = nil - recording.buffer = nil - } - return -} - -func outputWriter() io.Writer { - if PrintToError { - return GetErrorOutput() - } - return GetOutput() -} diff --git a/term/terminal.go b/term/terminal.go index a14834e1..8288635b 100644 --- a/term/terminal.go +++ b/term/terminal.go @@ -115,17 +115,6 @@ func (t *Terminal) getColorableWriter(w io.Writer) io.Writer { return colorable.NewNonColorable(w) } -// Size returns the width and height of the terminal session -func (t *Terminal) Size() (width, height int) { - fd := fdToInt(os.Stdout.Fd()) - var err error - width, height, err = term.GetSize(fd) - if err != nil { - width, height = 0, 0 - } - return -} - // FlushAllOutput flushes all buffered output (if supported by the underlying Writer). func (t *Terminal) FlushAllOutput() { for _, writer := range []io.Writer{t.colorableStdout, t.colorableStderr, t.stdout, t.stderr} { diff --git a/wrapper.go b/wrapper.go index be40f235..7364bbfc 100644 --- a/wrapper.go +++ b/wrapper.go @@ -22,7 +22,6 @@ import ( "github.com/creativeprojects/resticprofile/monitor/hook" "github.com/creativeprojects/resticprofile/restic" "github.com/creativeprojects/resticprofile/shell" - "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util" "github.com/creativeprojects/resticprofile/util/collect" ) @@ -438,8 +437,8 @@ func (r *resticWrapper) prepareCommand(command string, args *shell.Args, allowEx rCommand := newShellCommand(binary, arguments, env, r.getShell(), r.dryRun, r.sigChan, r.setPID) rCommand.publicArgs = publicArguments // stdout are stderr are coming from the default terminal (in case they're redirected) - rCommand.stdout = term.GetOutput() - rCommand.stderr = term.GetErrorOutput() + rCommand.stdout = r.ctx.terminal.Stdout() + rCommand.stderr = r.ctx.terminal.Stderr() rCommand.streamError = r.profile.StreamError rCommand.dir = dir @@ -545,7 +544,7 @@ func (r *resticWrapper) runCommand(command string) error { if len(r.progress) > 0 { if r.profile.Backup.ExtendedStatus { rCommand.scanOutput = shell.ScanBackupJson - } else if !term.OsStdoutIsTerminal() { + } else if !r.ctx.terminal.StdoutIsTerminal() { // restic detects its output is not a terminal and no longer displays the monitor. // Scan plain output only if resticprofile is not run from a terminal (e.g. schedule) rCommand.scanOutput = shell.ScanBackupPlain @@ -629,9 +628,9 @@ func (r *resticWrapper) runShellCommands(commands []string, commandsType, comman // creating command rCommand := newShellCommand(shellCommand, nil, env, r.getShell(), r.dryRun, r.sigChan, r.setPID) // stdout are stderr are coming from the default terminal (in case they're redirected) - rCommand.stdout = term.GetOutput() - rCommand.stderr = term.GetErrorOutput() - term.FlushAllOutput() + rCommand.stdout = r.ctx.terminal.Stdout() + rCommand.stderr = r.ctx.terminal.Stderr() + r.ctx.terminal.FlushAllOutput() _, stderr, err := runShellCommand(rCommand) if err != nil { err = fmt.Errorf("%s on profile '%s': %w", commandsType, r.profile.Name, err) @@ -659,9 +658,9 @@ func (r *resticWrapper) runFinalShellCommands(command string, fail error) { // creating command rCommand := newShellCommand(cmd, nil, env, r.getShell(), r.dryRun, r.sigChan, r.setPID) // stdout are stderr are coming from the default terminal (in case they're redirected) - rCommand.stdout = term.GetOutput() - rCommand.stderr = term.GetErrorOutput() - term.FlushAllOutput() + rCommand.stdout = r.ctx.terminal.Stdout() + rCommand.stderr = r.ctx.terminal.Stderr() + r.ctx.terminal.FlushAllOutput() _, _, err := runShellCommand(rCommand) if err != nil { clog.Errorf("run-finally command %d/%d failed ('%s' on profile '%s'): %s", @@ -694,7 +693,7 @@ func (r *resticWrapper) sendFinally(monitoring config.SendMonitoringSections, co func (r *resticWrapper) sendMonitoring(sections []config.SendMonitoringSection, command, sendType string, err error) { for i, section := range sections { clog.Debugf("starting %q from %s %d/%d", sendType, command, i+1, len(sections)) - term.FlushAllOutput() + r.ctx.terminal.FlushAllOutput() err := r.sender.Send(section, r.getContextWithError(err), r.profile.GetEnvironment(true)) if err != nil { clog.Warningf("%q returned an error: %s", sendType, err.Error()) diff --git a/wrapper_streamsource.go b/wrapper_streamsource.go index 205abad4..8b4c027a 100644 --- a/wrapper_streamsource.go +++ b/wrapper_streamsource.go @@ -14,7 +14,6 @@ import ( "time" "github.com/creativeprojects/clog" - "github.com/creativeprojects/resticprofile/term" ) func (r *resticWrapper) prepareStreamSource() (io.ReadCloser, error) { @@ -77,7 +76,7 @@ func (r *resticWrapper) prepareCommandStreamSource() (io.ReadCloser, error) { clog.Debugf("starting 'stdin-command' command %d/%d: %s", i+1, len(r.profile.Backup.StdinCommand), sourceCommand) rCommand := newShellCommand(sourceCommand, nil, env, r.getShell(), r.dryRun, commandSignals, nil) rCommand.stdout = bufferedWriter - rCommand.stderr = term.GetErrorOutput() + rCommand.stderr = r.ctx.terminal.Stderr() _, stderr, err := runShellCommand(rCommand) if err != nil { diff --git a/wrapper_test.go b/wrapper_test.go index 96cc8ddd..ad60a36a 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "errors" "fmt" "io" @@ -181,10 +182,11 @@ func TestFilteredArgumentsRegression(t *testing.T) { profile, err := cfg.GetProfile("default") require.NoError(t, err) wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - binary: "restic", - profile: profile, - command: "test", + flags: commandLineFlags{dryRun: true}, + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), }) for command, commandline := range test.expected { @@ -202,9 +204,10 @@ func TestGetEmptyEnvironment(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "restic", - profile: profile, - command: "test", + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) env := wrapper.getEnvironment(false) @@ -220,9 +223,10 @@ func TestGetSingleEnvironment(t *testing.T) { } profile.ResolveConfiguration() ctx := &Context{ - binary: "restic", - profile: profile, - command: "test", + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) env := wrapper.getEnvironment(false) @@ -239,9 +243,10 @@ func TestGetMultipleEnvironment(t *testing.T) { } profile.ResolveConfiguration() ctx := &Context{ - binary: "restic", - profile: profile, - command: "test", + binary: "restic", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) @@ -266,9 +271,10 @@ func TestPreProfileScriptFail(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunBefore = []string{"exit 1"} // this should both work on unix shell and windows batch ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -281,9 +287,10 @@ func TestPostProfileScriptFail(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunAfter = []string{"exit 1"} // this should both work on unix shell and windows batch ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -295,9 +302,10 @@ func TestRunEchoProfile(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -312,9 +320,10 @@ func TestPostProfileAfterFail(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunAfter = []string{"echo failed > " + testFile} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -331,9 +340,10 @@ func TestPostFailProfile(t *testing.T) { profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo failed > " + testFile} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -367,9 +377,10 @@ func TestFinallyProfile(t *testing.T) { t.Run("backup-before-profile", func(t *testing.T) { newProfile() ctx := &Context{ - binary: "echo", - profile: profile, - command: "backup", + binary: "echo", + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -381,9 +392,10 @@ func TestFinallyProfile(t *testing.T) { newProfile() profile.RunFinally = nil ctx := &Context{ - binary: "echo", - profile: profile, - command: "backup", + binary: "echo", + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -394,9 +406,10 @@ func TestFinallyProfile(t *testing.T) { t.Run("on-error", func(t *testing.T) { newProfile() ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -406,12 +419,12 @@ func TestFinallyProfile(t *testing.T) { } func Example_runProfile() { - term.SetOutput(os.Stdout) profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -422,136 +435,144 @@ func Example_runProfile() { } func TestRunRedirectOutputOfEchoProfile(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "test", strings.TrimSpace(term.ReadRecording())) + assert.Equal(t, "test", strings.TrimSpace(buffer.String())) } func TestDryRun(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - binary: "echo", - profile: profile, - command: "test", + flags: commandLineFlags{dryRun: true}, + binary: "echo", + profile: profile, + command: "test", + terminal: terminal, }) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "", term.ReadRecording()) + assert.Equal(t, "", buffer.String()) } func TestEnvProfileName(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "TestEnvProfileName") profile.RunBefore = []string{"echo profile name = $PROFILE_NAME"} ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "profile name = TestEnvProfileName\ntest\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) + assert.Equal(t, "profile name = TestEnvProfileName\ntest\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) } func TestEnvProfileCommand(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunBefore = []string{"echo profile command = $PROFILE_COMMAND"} ctx := &Context{ - binary: "echo", - profile: profile, - command: "test-command", + binary: "echo", + profile: profile, + command: "test-command", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.NoError(t, err) - assert.Equal(t, "profile command = test-command\ntest-command\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) + assert.Equal(t, "profile command = test-command\ntest-command\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) } func TestEnvError(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo error: $ERROR_MESSAGE"} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "error: 1 on profile 'name': exit status 1\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) + assert.Equal(t, "error: 1 on profile 'name': exit status 1\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) } func TestEnvErrorCommandLine(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo cmd: $ERROR_COMMANDLINE"} ctx := &Context{ - binary: "exit", - profile: profile, - command: "1", + binary: "exit", + profile: profile, + command: "1", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "cmd: \"exit\" \"1\"\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) + assert.Equal(t, "cmd: \"exit\" \"1\"\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) } func TestEnvErrorExitCode(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo exit-code: $ERROR_EXIT_CODE"} ctx := &Context{ - binary: "exit", - profile: profile, - command: "5", + binary: "exit", + profile: profile, + command: "5", + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "exit-code: 5\n", strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n")) + assert.Equal(t, "exit-code: 5\n", strings.ReplaceAll(buffer.String(), "\r\n", "\n")) } func TestEnvStderr(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) profile := config.NewProfile(nil, "name") profile.RunAfterFail = []string{"echo stderr: $ERROR_STDERR"} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "command", - request: Request{arguments: []string{"--stderr", "error_message", "--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "command", + request: Request{arguments: []string{"--stderr", "error_message", "--exit", "1"}}, + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() assert.Error(t, err) - assert.Equal(t, "stderr: error_message", strings.TrimSpace(strings.ReplaceAll(term.ReadRecording(), "\r\n", "\n"))) + assert.Equal(t, "stderr: error_message", strings.TrimSpace(strings.ReplaceAll(buffer.String(), "\r\n", "\n"))) } func TestRunProfileWithSetPIDCallback(t *testing.T) { @@ -561,9 +582,10 @@ func TestRunProfileWithSetPIDCallback(t *testing.T) { profile.Lock = filepath.Join(os.TempDir(), fmt.Sprintf("%s%d%d.tmp", "TestRunProfileWithSetPIDCallback", time.Now().UnixNano(), os.Getpid())) t.Logf("lockfile = %s", profile.Lock) ctx := &Context{ - binary: "echo", - profile: profile, - command: "test", + binary: "echo", + profile: profile, + command: "test", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -575,9 +597,10 @@ func TestInitializeNoError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitialize() @@ -589,10 +612,11 @@ func TestInitializeWithError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitialize() @@ -605,9 +629,10 @@ func TestInitializeCopyNoError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Copy = &config.CopySection{InitializeCopyChunkerParams: maybe.False()} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitializeCopy() @@ -620,10 +645,11 @@ func TestInitializeCopyWithError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Copy = &config.CopySection{InitializeCopyChunkerParams: maybe.False()} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runInitializeCopy() @@ -635,9 +661,10 @@ func TestCheckNoError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCheck() @@ -649,10 +676,11 @@ func TestCheckWithError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCheck() @@ -664,9 +692,10 @@ func TestRetentionNoError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runRetention() @@ -678,10 +707,11 @@ func TestRetentionWithError(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "10"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "10"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runRetention() @@ -710,10 +740,11 @@ func TestBackupWithStreamSource(t *testing.T) { profile.Backup = &config.BackupSection{} signals := make(chan os.Signal, 1) ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "stdin-test", - sigChan: signals, + binary: mockBinary, + profile: profile, + command: "stdin-test", + sigChan: signals, + terminal: term.NewTerminal(), } wrapper = newResticWrapper(ctx) return @@ -853,9 +884,10 @@ func TestBackupWithSuccess(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -868,10 +900,11 @@ func TestBackupWithError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "1"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -898,12 +931,13 @@ func TestBackupWithResticLockFailureRetried(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - global: global, - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, - sigChan: sigChan, + global: global, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, + sigChan: sigChan, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.lockWait = &lockWait @@ -934,12 +968,13 @@ func TestBackupWithResticLockFailureCancelled(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - global: global, - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, - sigChan: sigChan, + global: global, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--stderr", "@" + tempfile, "--exit", "1"}}, + sigChan: sigChan, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.lockWait = &lockWait @@ -959,10 +994,11 @@ func TestBackupWithNoConfiguration(t *testing.T) { profile := config.NewProfile(nil, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "1"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -975,10 +1011,11 @@ func TestBackupWithNoConfigurationButStatusFile(t *testing.T) { profile := config.NewProfile(nil, "name") profile.StatusFile = "status.json" ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "1"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "1"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.addProgress(status.NewProgress(profile, status.NewStatus("status.json"))) @@ -992,10 +1029,11 @@ func TestBackupWithWarningAsError(t *testing.T) { profile := config.NewProfile(nil, "name") profile.Backup = &config.BackupSection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "3"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "3"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -1008,10 +1046,11 @@ func TestBackupWithSupressedWarnings(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Backup = &config.BackupSection{NoErrorOnWarning: true} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", - request: Request{arguments: []string{"--exit", "3"}}, + binary: mockBinary, + profile: profile, + command: "", + request: Request{arguments: []string{"--exit", "3"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runCommand("backup") @@ -1042,9 +1081,10 @@ func TestRunShellCommands(t *testing.T) { t.Run(fmt.Sprintf("run-before '%s'", command), func(t *testing.T) { section.RunBefore = []string{"exit 2"} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: command, + binary: mockBinary, + profile: profile, + command: command, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -1060,9 +1100,10 @@ func TestRunShellCommands(t *testing.T) { t.Run(fmt.Sprintf("run-after '%s'", command), func(t *testing.T) { section.RunAfter = []string{"exit 2"} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: command, + binary: mockBinary, + profile: profile, + command: command, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() @@ -1078,8 +1119,8 @@ func TestRunShellCommands(t *testing.T) { } func TestRunStreamErrorHandler(t *testing.T) { - term.StartRecording(term.RecordOutput) - defer term.StopRecording() + buffer := new(bytes.Buffer) + terminal := term.NewTerminal(term.WithStdout(buffer)) errorCommand := `echo "detected error in $PROFILE_COMMAND"` @@ -1087,16 +1128,17 @@ func TestRunStreamErrorHandler(t *testing.T) { profile.Backup = &config.BackupSection{} profile.StreamError = []config.StreamErrorSection{{Pattern: ".+error-line.+", Run: errorCommand}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", - request: Request{arguments: []string{"--stderr", "--error-line--"}}, + binary: mockBinary, + profile: profile, + command: "backup", + request: Request{arguments: []string{"--stderr", "--error-line--"}}, + terminal: terminal, } wrapper := newResticWrapper(ctx) err := wrapper.runProfile() require.NoError(t, err) - assert.Contains(t, term.ReadRecording(), "detected error in backup") + assert.Contains(t, buffer.String(), "detected error in backup") } func TestRunStreamErrorHandlerDoesNotBreakCommand(t *testing.T) { @@ -1106,10 +1148,11 @@ func TestRunStreamErrorHandlerDoesNotBreakCommand(t *testing.T) { profile.Backup = &config.BackupSection{} profile.StreamError = []config.StreamErrorSection{{Pattern: ".+error-line.+", Run: "exit 1"}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", - request: Request{arguments: []string{"--stderr", "--error-line--"}}, + binary: mockBinary, + profile: profile, + command: "backup", + request: Request{arguments: []string{"--stderr", "--error-line--"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) @@ -1124,10 +1167,11 @@ func TestStreamErrorHandlerWithInvalidRegex(t *testing.T) { profile.Backup = &config.BackupSection{} profile.StreamError = []config.StreamErrorSection{{Pattern: "(", Run: "echo pass"}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", - request: Request{arguments: []string{}}, + binary: mockBinary, + profile: profile, + command: "backup", + request: Request{arguments: []string{}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) @@ -1140,9 +1184,10 @@ func TestCanRetryAfterErrorDontFailWhenNoOutputAnalysis(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", + binary: mockBinary, + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) summary := monitor.Summary{} @@ -1163,9 +1208,10 @@ func TestCanRetryAfterRemoteStaleLockFailure(t *testing.T) { profile.Repository = config.NewConfidentialValue("my-repo") profile.ForceLock = true ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", + binary: mockBinary, + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.startTime = time.Now() @@ -1228,9 +1274,10 @@ func TestCanRetryAfterRemoteLockFailure(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Repository = config.NewConfidentialValue("my-repo") ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "backup", + binary: mockBinary, + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) wrapper.startTime = time.Now() @@ -1283,10 +1330,11 @@ func TestCanUseResticLockRetry(t *testing.T) { getWrapper := func() *resticWrapper { wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - binary: "restic", - profile: profile, - command: constants.CommandBackup, + flags: commandLineFlags{dryRun: true}, + binary: "restic", + profile: profile, + command: constants.CommandBackup, + terminal: term.NewTerminal(), }) wrapper.startTime = time.Now() wrapper.global.ResticLockRetryAfter = 1 * time.Minute @@ -1369,23 +1417,26 @@ func TestLocksAndLockWait(t *testing.T) { profile.Lock = filepath.Join(os.TempDir(), fmt.Sprintf("%s%d%d.tmp", "TestLockWait", time.Now().UnixNano(), os.Getpid())) defer os.Remove(profile.Lock) - term.SetOutput(os.Stdout) + terminal := term.NewTerminal() ctx1 := &Context{ - binary: mockBinary, - profile: profile, - command: constants.CommandBackup, - request: Request{arguments: []string{"--sleep", "1500"}}, + binary: mockBinary, + profile: profile, + command: constants.CommandBackup, + request: Request{arguments: []string{"--sleep", "1500"}}, + terminal: terminal, } ctx2 := &Context{ - binary: mockBinary, - profile: profile, - command: constants.CommandBackup, + binary: mockBinary, + profile: profile, + command: constants.CommandBackup, + terminal: terminal, } ctx3 := &Context{ - binary: mockBinary, - profile: profile, - command: constants.CommandBackup, + binary: mockBinary, + profile: profile, + command: constants.CommandBackup, + terminal: terminal, } w1 := newResticWrapper(ctx1) w2 := newResticWrapper(ctx2) @@ -1674,11 +1725,12 @@ func TestRunInitCopyCommand(t *testing.T) { defer clog.SetDefaultLogger(defaultLogger) wrapper := newResticWrapper(&Context{ - flags: commandLineFlags{dryRun: true}, - global: config.NewGlobal(), - binary: "test", - profile: testCase.profile, - command: "copy", + flags: commandLineFlags{dryRun: true}, + global: config.NewGlobal(), + binary: "test", + profile: testCase.profile, + command: "copy", + terminal: term.NewTerminal(), }) // 1. run init command with copy profile err := wrapper.runInitializeCopy() @@ -1702,9 +1754,10 @@ func TestCopyNoSnapshot(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Copy = &config.CopySection{} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := shell.NewArgs() @@ -1718,9 +1771,10 @@ func TestCopySnapshot(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") profile.Copy = &config.CopySection{Snapshots: []string{"snapshot1", "snapshot2"}} ctx := &Context{ - binary: mockBinary, - profile: profile, - command: "", + binary: mockBinary, + profile: profile, + command: "", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := shell.NewArgs() @@ -1736,9 +1790,10 @@ func TestPrepareCommandShouldEscapeBinary(t *testing.T) { profile := config.NewProfile(&config.Config{}, "name") ctx := &Context{ - binary: "/full path to/restic", - profile: profile, - command: "backup", + binary: "/full path to/restic", + profile: profile, + command: "backup", + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := shell.NewArgs() @@ -1751,10 +1806,11 @@ func TestRunUnlockWithCommandLineFlags(t *testing.T) { profile := config.NewProfile(&config.Config{}, "TestProfile") ctx := &Context{ - binary: "restic", - profile: profile, - command: constants.CommandForget, - request: Request{arguments: []string{"some-string", "--some-flag", "-n"}}, + binary: "restic", + profile: profile, + command: constants.CommandForget, + request: Request{arguments: []string{"some-string", "--some-flag", "-n"}}, + terminal: term.NewTerminal(), } wrapper := newResticWrapper(ctx) args := profile.GetCommandFlags(constants.CommandUnlock) From 248458ad045f464cbac0735828393e45b333a819 Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 15:38:06 +0100 Subject: [PATCH 03/15] feat: refactor terminal handling and remove deprecated functions --- commands_display.go | 2 +- commands_test.go | 5 ++- logger.go | 25 ++++++++----- logger_test.go | 7 ++-- main.go | 5 ++- schedule/handler_darwin.go | 2 +- term/recording.go | 75 -------------------------------------- term/term.go | 34 +++-------------- util/filewriter.go | 2 +- 9 files changed, 35 insertions(+), 122 deletions(-) delete mode 100644 term/recording.go diff --git a/commands_display.go b/commands_display.go index d6b87b0e..9b05262c 100644 --- a/commands_display.go +++ b/commands_display.go @@ -22,7 +22,7 @@ import ( ) func displayWriter(terminal *term.Terminal) (out func(args ...any) io.Writer, closer func()) { - var output io.Writer = terminal.Stdout() + output := terminal.Stdout() if terminal.StdoutIsTerminal() { if width, _ := term.Size(); width > 10 { output = ansi.NewLineLengthWriter(terminal, width) diff --git a/commands_test.go b/commands_test.go index f1c8d2a7..3b1855f0 100644 --- a/commands_test.go +++ b/commands_test.go @@ -321,7 +321,6 @@ func TestGenerateCommand(t *testing.T) { } func TestShowSchedules(t *testing.T) { - buffer := &bytes.Buffer{} create := func(command string, at ...string) *config.Schedule { origin := config.ScheduleOrigin("default", command) return config.NewDefaultSchedule(nil, origin, at...) @@ -349,7 +348,9 @@ schedule check@default: lock-mode: default capture-environment: RESTIC_* `) - showSchedules(buffer, schedules) + buffer := &bytes.Buffer{} + terminal := term.NewTerminal(term.WithStdout(buffer)) + showSchedules(terminal.Stdout(), schedules) assert.Equal(t, expected, strings.TrimSpace(buffer.String())) } diff --git a/logger.go b/logger.go index 40c5d25b..816609b4 100644 --- a/logger.go +++ b/logger.go @@ -49,7 +49,7 @@ func setupRemoteLogger(flags commandLineFlags, client *remote.Client) { clog.SetDefaultLogger(logger) } -func setupTargetLogger(flags commandLineFlags, logTarget, commandOutput string) (io.Closer, error) { +func setupTargetLogger(flags commandLineFlags, terminal *term.Terminal, logTarget, commandOutput string) (io.Closer, *term.Terminal, error) { var ( handler LogCloser file io.Writer @@ -63,29 +63,36 @@ func setupTargetLogger(flags commandLineFlags, logTarget, commandOutput string) handler, file, err = getFileHandler(logTarget) } if err != nil { - return nil, err + return nil, nil, err } // use the console handler as a backup logger := newFilteredLogger(flags, clog.NewSafeHandler(handler, clog.NewConsoleHandler("", log.LstdFlags))) // default logger added with level filtering clog.SetDefaultLogger(logger) + var newTerminal *term.Terminal + // also redirect all terminal output if file != nil { - if all, toLog := parseCommandOutput(commandOutput); all { - term.SetOutput(io.MultiWriter(file, term.GetOutput())) - term.SetErrorOutput(io.MultiWriter(file, term.GetErrorOutput())) + if all, toLog := parseCommandOutput(terminal, commandOutput); all { + newTerminal = term.NewTerminal( + term.WithStdout(io.MultiWriter(file, terminal.Stdout())), + term.WithStderr(io.MultiWriter(file, terminal.Stderr())), + ) } else if toLog { - term.SetAllOutput(file) + newTerminal = term.NewTerminal( + term.WithStdout(file), + term.WithStderr(file), + ) } } // and return the handler (so we can close it at the end) - return handler, nil + return handler, newTerminal, nil } -func parseCommandOutput(commandOutput string) (all, log bool) { +func parseCommandOutput(terminal *term.Terminal, commandOutput string) (all, log bool) { if strings.TrimSpace(commandOutput) == "auto" { - if term.OsStdoutIsTerminal() { + if terminal.StdoutIsTerminal() { commandOutput = "log,console" } else { commandOutput = "log" diff --git a/logger_test.go b/logger_test.go index 3ed5421c..ae335e0a 100644 --- a/logger_test.go +++ b/logger_test.go @@ -99,12 +99,13 @@ func TestFileHandler(t *testing.T) { } func TestParseCommandOutput(t *testing.T) { + terminal := term.NewTerminal() tests := []struct { co string all, log bool }{ {co: "", all: false, log: false}, - {co: "auto", all: term.OsStdoutIsTerminal(), log: true}, + {co: "auto", all: terminal.StdoutIsTerminal(), log: true}, {co: "log", all: false, log: true}, {co: "console", all: false, log: false}, {co: "all", all: true, log: false}, @@ -114,7 +115,7 @@ func TestParseCommandOutput(t *testing.T) { {co: "log,a", all: false, log: true}, {co: "console,a", all: false, log: false}, - {co: " auto ", all: term.OsStdoutIsTerminal(), log: true}, + {co: " auto ", all: terminal.StdoutIsTerminal(), log: true}, {co: " all ", all: true, log: false}, {co: " log ", all: false, log: true}, {co: " console ", all: false, log: false}, @@ -124,7 +125,7 @@ func TestParseCommandOutput(t *testing.T) { for i, test := range tests { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - a, l := parseCommandOutput(test.co) + a, l := parseCommandOutput(terminal, test.co) assert.Equal(t, test.all, a, "all") assert.Equal(t, test.log, l, "log") }) diff --git a/main.go b/main.go index c621bb14..346d0483 100644 --- a/main.go +++ b/main.go @@ -131,8 +131,11 @@ func main() { terminal = term.Set(term.NewTerminal(term.WithStdout(os.Stderr), term.WithColors(!flags.noAnsi))) } if logTarget != "" && logTarget != "-" { - if closer, err := setupTargetLogger(flags, logTarget, commandOutput); err == nil { + if closer, newTerminal, err := setupTargetLogger(flags, terminal, logTarget, commandOutput); err == nil { logCloser = func() { _ = closer.Close() } + if newTerminal != nil { + terminal = term.Set(newTerminal) + } } else { // fallback to a console logger setupConsoleLogger(flags) diff --git a/schedule/handler_darwin.go b/schedule/handler_darwin.go index db6bfeb6..c832c3b6 100644 --- a/schedule/handler_darwin.go +++ b/schedule/handler_darwin.go @@ -151,7 +151,7 @@ func (h *HandlerLaunchd) DisplayJobStatus(job *Config) error { // output was not parsed, it could mean output format has changed clog.Warning("output of 'launchctl print' was either empty or using an incompatible format") } - writer := tabwriter.NewWriter(term.GetOutput(), 1, 1, 1, ' ', tabwriter.AlignRight) + writer := tabwriter.NewWriter(term.Get().Stdout(), 1, 1, 1, ' ', tabwriter.AlignRight) for _, key := range launchctlPrintKeys { key, value := presentStatus(key, status[key]) if len(value) == 0 { diff --git a/term/recording.go b/term/recording.go deleted file mode 100644 index 5fe6e86a..00000000 --- a/term/recording.go +++ /dev/null @@ -1,75 +0,0 @@ -package term - -import ( - "bytes" - "io" - "sync" - - "github.com/creativeprojects/resticprofile/util" -) - -type outputRecording struct { - lock sync.Mutex - buffer *bytes.Buffer - writer io.Writer - output, error io.Writer -} - -type RecordMode uint8 - -const ( - RecordOutput RecordMode = iota - RecordError - RecordBoth -) - -func (r *outputRecording) StartRecording(mode RecordMode) { - r.lock.Lock() - defer r.lock.Unlock() - if r.buffer != nil { - return - } - - r.buffer = new(bytes.Buffer) - r.writer = util.NewSyncWriterMutex(r.buffer, &r.lock) - - if mode != RecordError { - r.output = GetOutput() - setOutput(r.writer) - } - if mode != RecordOutput { - r.error = GetErrorOutput() - setErrorOutput(r.writer) - } -} - -func (r *outputRecording) ReadRecording() (content string) { - r.lock.Lock() - defer r.lock.Unlock() - if r.buffer != nil { - content = r.buffer.String() - r.buffer.Reset() - } - return -} - -func (r *outputRecording) StopRecording() (content string) { - r.lock.Lock() - defer r.lock.Unlock() - if r.buffer != nil { - if r.output != nil && r.writer == GetOutput() { - setOutput(r.output) - r.output = nil - } - - if r.error != nil && r.writer == GetErrorOutput() { - setErrorOutput(r.error) - r.error = nil - } - - content = r.buffer.String() - r.writer = nil - r.buffer = nil - } - return -} diff --git a/term/term.go b/term/term.go index eaf18c16..254ac2c9 100644 --- a/term/term.go +++ b/term/term.go @@ -13,13 +13,11 @@ import ( ) var ( - termOutput atomic.Pointer[io.Writer] - errorOutput atomic.Pointer[io.Writer] - colorOutput atomic.Pointer[io.Writer] - enableColors atomic.Bool - statusChannel = make(chan []string) - statusWaitChannel = make(chan chan bool) - PrintToError = false + termOutput atomic.Pointer[io.Writer] + errorOutput atomic.Pointer[io.Writer] + colorOutput atomic.Pointer[io.Writer] + enableColors atomic.Bool + PrintToError = false ) const ( @@ -35,13 +33,6 @@ func init() { } } -func truncate[E any](src []E, maxLength int) []E { - if len(src) > maxLength { - return src[:maxLength] - } - return src -} - // ReadPassword reads a password without echoing it to the terminal. // // Deprecated: use term.Terminal instead @@ -68,13 +59,6 @@ func readLine() (string, error) { return strings.TrimSpace(line), nil } -// OsStdoutIsTerminal returns true as os.Stdout is a terminal session -// -// Deprecated: use term.Terminal instead -func OsStdoutIsTerminal() bool { - return isTerminal(os.Stdout) -} - // SetOutput changes the default output for the Print* functions // // Deprecated: use term.Terminal instead @@ -131,11 +115,3 @@ func GetErrorOutput() (out io.Writer) { } return } - -// SetAllOutput changes the default and error output for the Print* functions -// -// Deprecated: use term.Terminal instead -func SetAllOutput(w io.Writer) { - SetOutput(w) - setErrorOutput(GetOutput()) -} diff --git a/util/filewriter.go b/util/filewriter.go index 1e7e11a6..9b90acaf 100644 --- a/util/filewriter.go +++ b/util/filewriter.go @@ -59,7 +59,7 @@ var ( func asyncWriterReturnToPool(data []byte) { if cap(data) == asyncWriterBlockSize { - asyncWriterBufferPool.Put(data[:0]) + asyncWriterBufferPool.Put(data[:0]) //nolint:staticcheck } } From 80ffa8244938119278d3624529fa1bb5438032fc Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 15:43:08 +0100 Subject: [PATCH 04/15] refactor: update terminal handling to use new terminal instance and remove deprecated functions --- main.go | 1 - schedule/handler_systemd.go | 17 +++--- schtasks/permission.go | 2 +- schtasks/taskscheduler.go | 2 +- term/term.go | 117 ------------------------------------ 5 files changed, 12 insertions(+), 127 deletions(-) delete mode 100644 term/term.go diff --git a/main.go b/main.go index 346d0483..4229dc0b 100644 --- a/main.go +++ b/main.go @@ -127,7 +127,6 @@ func main() { clog.Debugf("redirecting console to stderr for command %q", ctx.request.command) flags.stderr = true } - term.PrintToError = flags.stderr terminal = term.Set(term.NewTerminal(term.WithStdout(os.Stderr), term.WithColors(!flags.noAnsi))) } if logTarget != "" && logTarget != "-" { diff --git a/schedule/handler_systemd.go b/schedule/handler_systemd.go index 47b66b2a..02e3a4e0 100644 --- a/schedule/handler_systemd.go +++ b/schedule/handler_systemd.go @@ -117,7 +117,7 @@ func (h *HandlerSystemd) DisplayStatus(profileName string) error { // fail silently return nil //nolint:nilerr } - fmt.Fprintf(term.GetOutput(), "\nTimers summary\n===============\n%s\n", status) + fmt.Fprintf(term.Get().Stdout(), "\nTimers summary\n===============\n%s\n", status) return nil } @@ -424,8 +424,9 @@ func runSystemctlOnUnit(timerName, command string, unitType systemd.UnitType, si return err } if !silent { - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + terminal := term.Get() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() } err = cmd.Run() if command == systemctlStatus && cmd.ProcessState.ExitCode() == codeStatusUnitNotFound { @@ -454,8 +455,9 @@ func runJournalCtlCommand(timerName string, unitType systemd.UnitType) error { } clog.Debugf("starting command \"%s %s\"", binary, strings.Join(args, " ")) cmd := exec.CommandContext(context.TODO(), binary, args...) - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + terminal := term.Get() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() err = cmd.Run() fmt.Println("") return err @@ -470,8 +472,9 @@ func runSystemctlReload(unitType systemd.UnitType) error { if err != nil { return err } - cmd.Stdout = term.GetOutput() - cmd.Stderr = term.GetErrorOutput() + terminal := term.Get() + cmd.Stdout = terminal.Stdout() + cmd.Stderr = terminal.Stderr() err = cmd.Run() if err != nil { return err diff --git a/schtasks/permission.go b/schtasks/permission.go index 3f212754..5004feb6 100644 --- a/schtasks/permission.go +++ b/schtasks/permission.go @@ -40,7 +40,7 @@ func userCredentials() (string, string, error) { fmt.Printf("\nCreating task for user %s\n", userName) fmt.Printf("Task Scheduler requires your Windows password to validate the task: ") - userPassword, err = term.ReadPassword() + userPassword, err = term.Get().ReadPassword() if err != nil { return "", "", err } diff --git a/schtasks/taskscheduler.go b/schtasks/taskscheduler.go index 06e1871c..3a44bea5 100644 --- a/schtasks/taskscheduler.go +++ b/schtasks/taskscheduler.go @@ -125,7 +125,7 @@ func Status(title, subtitle string) error { if len(info) < 2 { return ErrNotRegistered } - writer := tabwriter.NewWriter(term.GetOutput(), 2, 2, 2, ' ', tabwriter.AlignRight) + writer := tabwriter.NewWriter(term.Get().Stdout(), 2, 2, 2, ' ', tabwriter.AlignRight) fmt.Fprintf(writer, "Task:\t %s\n", getFirstField(info, "TaskName")) fmt.Fprintf(writer, "User:\t %s\n", getFirstField(info, "Run As User")) fmt.Fprintf(writer, "Logon Mode:\t %s\n", getFirstField(info, "Logon Mode")) diff --git a/term/term.go b/term/term.go deleted file mode 100644 index 254ac2c9..00000000 --- a/term/term.go +++ /dev/null @@ -1,117 +0,0 @@ -package term - -import ( - "bufio" - "fmt" - "io" - "os" - "strings" - "sync/atomic" - - "github.com/creativeprojects/resticprofile/util" - "golang.org/x/term" -) - -var ( - termOutput atomic.Pointer[io.Writer] - errorOutput atomic.Pointer[io.Writer] - colorOutput atomic.Pointer[io.Writer] - enableColors atomic.Bool - PrintToError = false -) - -const ( - StatusFPS = 10 -) - -func init() { - enableColors.Store(true) - // must be last - { - setOutput(os.Stdout) - setErrorOutput(os.Stderr) - } -} - -// ReadPassword reads a password without echoing it to the terminal. -// -// Deprecated: use term.Terminal instead -func ReadPassword() (string, error) { - stdin := fdToInt(os.Stdin.Fd()) - if !term.IsTerminal(stdin) { - return readLine() - } - line, err := term.ReadPassword(stdin) - _, _ = fmt.Fprintln(os.Stderr) - if err != nil { - return "", fmt.Errorf("failed to read password: %w", err) - } - return string(line), nil -} - -// readLine reads some input -func readLine() (string, error) { - buf := bufio.NewReader(os.Stdin) - line, err := buf.ReadString('\n') - if err != nil { - return "", fmt.Errorf("failed to read line: %w", err) - } - return strings.TrimSpace(line), nil -} - -// SetOutput changes the default output for the Print* functions -// -// Deprecated: use term.Terminal instead -func SetOutput(w io.Writer) { - if w == os.Stdout && isTerminal(os.Stdout) { - setOutput(os.Stdout) - } else { - setOutput(util.NewSyncWriter(w)) - } -} - -func setOutput(w io.Writer) { - if w == nil { - w = io.Discard - } - termOutput.Store(&w) - colorOutput.Store(nil) -} - -// GetOutput returns the default output of the Print* functions -// -// Deprecated: use term.Terminal instead -func GetOutput() (out io.Writer) { - if v := termOutput.Load(); v != nil { - out = *v - } - return -} - -// SetErrorOutput changes the error output for the Print* functions -// -// Deprecated: use term.Terminal instead -func SetErrorOutput(w io.Writer) { - if w == os.Stderr && isTerminal(os.Stderr) { - setErrorOutput(os.Stderr) - } else { - setErrorOutput(util.NewSyncWriter(w)) - } -} - -func setErrorOutput(w io.Writer) { - if w == nil { - w = io.Discard - } - errorOutput.Store(&w) -} - -// GetErrorOutput returns the error output of the Print* functions -// -// Deprecated: use term.Terminal instead -func GetErrorOutput() (out io.Writer) { - if v := errorOutput.Load(); v != nil { - out = *v - } - return -} From 577c44261231478b13c37a6dbb4f1a4cb39545dc Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 16:13:07 +0100 Subject: [PATCH 05/15] fix: update nilReader to return io.EOF and improve error handling in filewriter and reader --- term/nil_reader.go | 4 +++- util/filewriter.go | 14 ++++++++------ util/filewriter_test.go | 13 +++++++++++++ util/reader.go | 2 +- util/reader_test.go | 10 +++------- 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/term/nil_reader.go b/term/nil_reader.go index bfea236a..4ae07109 100644 --- a/term/nil_reader.go +++ b/term/nil_reader.go @@ -1,7 +1,9 @@ package term +import "io" + type nilReader struct{} func (nilReader) Read(p []byte) (int, error) { - return 0, nil + return 0, io.EOF } diff --git a/util/filewriter.go b/util/filewriter.go index 9b90acaf..fa939350 100644 --- a/util/filewriter.go +++ b/util/filewriter.go @@ -1,6 +1,7 @@ package util import ( + "errors" "fmt" "io" "os" @@ -86,7 +87,7 @@ func NewAsyncFileWriter(filename string, options ...asyncFileWriterOption) (io.W closeFile := func() { if file != nil { - lastError = file.Close() + lastError = errors.Join(lastError, file.Close()) file = nil } } @@ -110,11 +111,11 @@ func NewAsyncFileWriter(filename string, options ...asyncFileWriterOption) (io.W } else { buffer = buffer[:0] } - } - if w.keepOpen { - _ = file.Sync() - } else { - closeFile() + if w.keepOpen { + _ = file.Sync() + } else { + closeFile() + } } } @@ -195,6 +196,7 @@ type asyncFileWriter struct { perm os.FileMode keepOpen bool interval time.Duration + closeOnce sync.Once } func (w *asyncFileWriter) Close() error { diff --git a/util/filewriter_test.go b/util/filewriter_test.go index 3c4cf690..93104f6c 100644 --- a/util/filewriter_test.go +++ b/util/filewriter_test.go @@ -12,6 +12,19 @@ import ( "github.com/stretchr/testify/require" ) +// func TestAsyncFileWriterDoubleClose(t *testing.T) { +// dir := t.TempDir() +// filename := filepath.Join(dir, "test.log") + +// w, err := NewAsyncFileWriter(filename) +// require.NoError(t, err) + +// err = w.Close() +// assert.NoError(t, err) +// err = w.Close() +// assert.NoError(t, err) +// } + func TestAsyncFileWriterBasicWrite(t *testing.T) { dir := t.TempDir() filename := filepath.Join(dir, "test.log") diff --git a/util/reader.go b/util/reader.go index 8a969c62..1f22c83b 100644 --- a/util/reader.go +++ b/util/reader.go @@ -33,7 +33,7 @@ type filterReadCloser struct { close CloseFunc } -func (c filterReadCloser) Close() error { +func (c *filterReadCloser) Close() error { defer func() { c.close = nil c.read = nil diff --git a/util/reader_test.go b/util/reader_test.go index 39f98ffc..244fec3f 100644 --- a/util/reader_test.go +++ b/util/reader_test.go @@ -97,8 +97,6 @@ func TestFilterReadCloserClose(t *testing.T) { } func TestFilterReadCloserCloseCalledTwice(t *testing.T) { - // Close() uses a value receiver so the nil-guard mutation does not persist; - // calling Close twice invokes the close func both times. calls := 0 inner := strings.NewReader("") rc := NewFilterReadCloser(inner.Read, func() error { @@ -108,7 +106,7 @@ func TestFilterReadCloserCloseCalledTwice(t *testing.T) { _ = rc.Close() _ = rc.Close() - assert.Equal(t, 2, calls) + assert.Equal(t, 1, calls) } func TestFilterReadCloserCloseError(t *testing.T) { @@ -138,8 +136,6 @@ func TestFilterReadCloserNilCloseFunc(t *testing.T) { } func TestFilterReadCloserReadAfterClose(t *testing.T) { - // Close() uses a value receiver so read is NOT cleared on the original; - // reads after close continue to work via the underlying reader. inner := strings.NewReader("data") rc := NewFilterReadCloser(inner.Read, func() error { return nil }) @@ -147,8 +143,8 @@ func TestFilterReadCloserReadAfterClose(t *testing.T) { buf := make([]byte, 4) n, err := rc.Read(buf) - assert.NoError(t, err) - assert.Equal(t, "data", string(buf[:n])) + assert.Error(t, err) + assert.Equal(t, 0, n) } // syncReader tests From 561008a631f227e19e41e6d42976e174c33251c0 Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 21:34:54 +0100 Subject: [PATCH 06/15] refactor: update terminal handling to use atomic pointers and improve concurrency safety; remove deprecated reader functions --- term/default.go | 14 +- term/default_test.go | 43 ++++++ util/filewriter.go | 21 +-- util/filewriter_test.go | 24 ++-- util/reader.go | 94 ------------- util/reader_test.go | 289 ---------------------------------------- 6 files changed, 75 insertions(+), 410 deletions(-) create mode 100644 term/default_test.go delete mode 100644 util/reader.go delete mode 100644 util/reader_test.go diff --git a/term/default.go b/term/default.go index 4e0bcb08..cf98883b 100644 --- a/term/default.go +++ b/term/default.go @@ -2,24 +2,24 @@ package term import ( "os" + "sync/atomic" "golang.org/x/term" ) -var defaultTerminal *Terminal +var defaultTerminal atomic.Pointer[Terminal] // Get returns the default terminal. It will be initialized on first use with NewTerminal() if not set before. +// Ideally we should pass the terminal where we need it, but Get() can be safely used until the refactoring is finished. func Get() *Terminal { - if defaultTerminal == nil { - defaultTerminal = NewTerminal() - } - return defaultTerminal + defaultTerminal.CompareAndSwap(nil, NewTerminal()) + return defaultTerminal.Load() } // Set stores the default terminal, and returns the terminal reference for chaining. func Set(t *Terminal) *Terminal { - defaultTerminal = t - return defaultTerminal + defaultTerminal.Store(t) + return t } // Size returns the width and height of the terminal session diff --git a/term/default_test.go b/term/default_test.go new file mode 100644 index 00000000..f34eedfa --- /dev/null +++ b/term/default_test.go @@ -0,0 +1,43 @@ +package term + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetTerminalSingleton(t *testing.T) { + total := 10 + wg := new(sync.WaitGroup) + terminals := make([]*Terminal, total) + for i := range total { + wg.Go(func() { + terminals[i] = Get() + }) + } + wg.Wait() + + for i := range total { + assert.NotNil(t, terminals[i]) + assert.Same(t, terminals[0], terminals[i]) + } +} + +func TestSetAndGetTerminalSingleton(t *testing.T) { + terminal := Set(NewTerminal()) + total := 10 + wg := new(sync.WaitGroup) + terminals := make([]*Terminal, total) + for i := range total { + wg.Go(func() { + terminals[i] = Get() + }) + } + wg.Wait() + + for i := range total { + assert.NotNil(t, terminals[i]) + assert.Same(t, terminal, terminals[i]) + } +} diff --git a/util/filewriter.go b/util/filewriter.go index fa939350..a462022c 100644 --- a/util/filewriter.go +++ b/util/filewriter.go @@ -2,10 +2,10 @@ package util import ( "errors" - "fmt" "io" "os" "sync" + "sync/atomic" "time" "github.com/creativeprojects/resticprofile/platform" @@ -188,6 +188,8 @@ func NewAsyncFileWriter(filename string, options ...asyncFileWriterOption) (io.W return w, lastError } +var errAlreadyClosed = errors.New("file writer already closed") + type asyncFileWriter struct { done, flush chan chan error data chan []byte @@ -196,28 +198,31 @@ type asyncFileWriter struct { perm os.FileMode keepOpen bool interval time.Duration - closeOnce sync.Once + closed atomic.Bool } func (w *asyncFileWriter) Close() error { + if !w.closed.CompareAndSwap(false, true) { + return errAlreadyClosed + } req := make(chan error) w.done <- req return <-req } func (w *asyncFileWriter) Flush() error { + if w.closed.Load() { + return errAlreadyClosed + } req := make(chan error) w.flush <- req return <-req } func (w *asyncFileWriter) Write(data []byte) (n int, err error) { - defer func() { - msg := recover() - if msg != nil { - err = fmt.Errorf("panic: %v", msg) - } - }() + if w.closed.Load() { + return 0, errAlreadyClosed + } var buffer []byte if len(data) <= asyncWriterBlockSize { buffer = asyncWriterBufferPool.Get().([]byte)[:len(data)] diff --git a/util/filewriter_test.go b/util/filewriter_test.go index 93104f6c..2d968642 100644 --- a/util/filewriter_test.go +++ b/util/filewriter_test.go @@ -12,18 +12,18 @@ import ( "github.com/stretchr/testify/require" ) -// func TestAsyncFileWriterDoubleClose(t *testing.T) { -// dir := t.TempDir() -// filename := filepath.Join(dir, "test.log") - -// w, err := NewAsyncFileWriter(filename) -// require.NoError(t, err) - -// err = w.Close() -// assert.NoError(t, err) -// err = w.Close() -// assert.NoError(t, err) -// } +func TestAsyncFileWriterDoubleClose(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "test.log") + + w, err := NewAsyncFileWriter(filename) + require.NoError(t, err) + + err = w.Close() + assert.NoError(t, err) + err = w.Close() + assert.Error(t, err) +} func TestAsyncFileWriterBasicWrite(t *testing.T) { dir := t.TempDir() diff --git a/util/reader.go b/util/reader.go deleted file mode 100644 index 1f22c83b..00000000 --- a/util/reader.go +++ /dev/null @@ -1,94 +0,0 @@ -package util - -import ( - "io" - "sync" -) - -// ReadFunc is the callback in NewFilterReader & NewFilterReadCloser for io.Reader.Read. -type ReadFunc func(bytes []byte) (n int, err error) - -type filterReader struct { - read ReadFunc -} - -func (f *filterReader) Read(bytes []byte) (n int, err error) { - if f.read == nil { - err = io.EOF - return - } - return f.read(bytes) -} - -// NewFilterReader creates a new io.Reader redirects all read calls to ReadFunc -func NewFilterReader(read ReadFunc) io.Reader { - return &filterReader{read} -} - -// CloseFunc is the callback in NewFilterReadCloser for io.Closer.Close. It is guaranteed to be called once. -type CloseFunc func() (err error) - -type filterReadCloser struct { - filterReader - close CloseFunc -} - -func (c *filterReadCloser) Close() error { - defer func() { - c.close = nil - c.read = nil - }() - if c.close == nil { - return nil - } - return c.close() -} - -// NewFilterReadCloser creates a new io.ReadCloser redirects all calls to ReadFunc and CloseFunc -func NewFilterReadCloser(read ReadFunc, closer CloseFunc) io.ReadCloser { - return &filterReadCloser{*NewFilterReader(read).(*filterReader), closer} -} - -// NewSyncReader creates a new reader that is safe for concurrent use -func NewSyncReader[R io.Reader](reader R) SyncReader[R] { - mutex := new(sync.Mutex) - return NewSyncReaderMutex(reader, mutex, mutex) -} - -// NewSyncReaderMutex creates a new reader that is safe for concurrent use and synced with the specified sync.Mutex -func NewSyncReaderMutex[R io.Reader](reader R, mutex, closeMutex *sync.Mutex) SyncReader[R] { - return &syncReader[R]{reader: reader, mutex: mutex, closeMutex: closeMutex} -} - -// SyncReader implements an io.ReadCloser that is safe for concurrent use -type SyncReader[R io.Reader] interface { - io.ReadCloser - // Locked provides sync.Mutex locked access to the underlying reader - Locked(fn func(R) error) error -} - -type syncReader[R io.Reader] struct { - reader io.Reader - mutex, closeMutex *sync.Mutex -} - -func (w *syncReader[R]) Locked(fn func(reader R) error) error { - w.mutex.Lock() - defer w.mutex.Unlock() - return fn(w.reader.(R)) -} - -func (w *syncReader[R]) Read(p []byte) (n int, err error) { - w.mutex.Lock() - defer w.mutex.Unlock() - return w.reader.Read(p) -} - -func (w *syncReader[R]) Close() (err error) { - w.closeMutex.Lock() - defer w.closeMutex.Unlock() - if f, ok := w.reader.(io.Closer); ok { - err = f.Close() - } - return -} diff --git a/util/reader_test.go b/util/reader_test.go deleted file mode 100644 index 244fec3f..00000000 --- a/util/reader_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package util - -import ( - "errors" - "io" - "strings" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// readCloser wraps a reader and tracks whether Close was called. -type trackingReadCloser struct { - io.Reader - closed bool - err error -} - -func (t *trackingReadCloser) Close() error { - t.closed = true - return t.err -} - -// filterReader & filterReadCloser tests - -func TestFilterReaderRead(t *testing.T) { - inner := strings.NewReader("hello") - r := NewFilterReader(inner.Read) - - buf := make([]byte, 5) - n, err := r.Read(buf) - assert.NoError(t, err) - assert.Equal(t, 5, n) - assert.Equal(t, "hello", string(buf)) -} - -func TestFilterReaderReadEOF(t *testing.T) { - inner := strings.NewReader("") - r := NewFilterReader(inner.Read) - - buf := make([]byte, 4) - _, err := r.Read(buf) - assert.Equal(t, io.EOF, err) -} - -func TestFilterReaderNilReadFunc(t *testing.T) { - r := NewFilterReader(nil) - - buf := make([]byte, 4) - n, err := r.Read(buf) - assert.Equal(t, io.EOF, err) - assert.Equal(t, 0, n) -} - -func TestFilterReaderReadAll(t *testing.T) { - inner := strings.NewReader("read all") - r := NewFilterReader(inner.Read) - - data, err := io.ReadAll(r) - require.NoError(t, err) - assert.Equal(t, "read all", string(data)) -} - -func TestFilterReaderPropagatesError(t *testing.T) { - sentinel := errors.New("read error") - r := NewFilterReader(func([]byte) (int, error) { return 0, sentinel }) - - buf := make([]byte, 4) - _, err := r.Read(buf) - assert.Equal(t, sentinel, err) -} - -func TestFilterReadCloserRead(t *testing.T) { - inner := strings.NewReader("world") - rc := NewFilterReadCloser(inner.Read, func() error { return nil }) - - buf := make([]byte, 5) - n, err := rc.Read(buf) - assert.NoError(t, err) - assert.Equal(t, 5, n) - assert.Equal(t, "world", string(buf)) -} - -func TestFilterReadCloserClose(t *testing.T) { - closed := false - inner := strings.NewReader("") - rc := NewFilterReadCloser(inner.Read, func() error { - closed = true - return nil - }) - - err := rc.Close() - assert.NoError(t, err) - assert.True(t, closed) -} - -func TestFilterReadCloserCloseCalledTwice(t *testing.T) { - calls := 0 - inner := strings.NewReader("") - rc := NewFilterReadCloser(inner.Read, func() error { - calls++ - return nil - }) - - _ = rc.Close() - _ = rc.Close() - assert.Equal(t, 1, calls) -} - -func TestFilterReadCloserCloseError(t *testing.T) { - sentinel := errors.New("close error") - inner := strings.NewReader("") - rc := NewFilterReadCloser(inner.Read, func() error { return sentinel }) - - err := rc.Close() - assert.Equal(t, sentinel, err) -} - -func TestFilterReadCloserNilReadFunc(t *testing.T) { - rc := NewFilterReadCloser(nil, func() error { return nil }) - - buf := make([]byte, 4) - n, err := rc.Read(buf) - assert.Equal(t, io.EOF, err) - assert.Equal(t, 0, n) -} - -func TestFilterReadCloserNilCloseFunc(t *testing.T) { - inner := strings.NewReader("") - rc := NewFilterReadCloser(inner.Read, nil) - - err := rc.Close() - assert.NoError(t, err) -} - -func TestFilterReadCloserReadAfterClose(t *testing.T) { - inner := strings.NewReader("data") - rc := NewFilterReadCloser(inner.Read, func() error { return nil }) - - require.NoError(t, rc.Close()) - - buf := make([]byte, 4) - n, err := rc.Read(buf) - assert.Error(t, err) - assert.Equal(t, 0, n) -} - -// syncReader tests - -func TestNewSyncReaderRead(t *testing.T) { - inner := strings.NewReader("hello world") - sr := NewSyncReader[io.Reader](inner) - - buf := make([]byte, 11) - n, err := sr.Read(buf) - assert.NoError(t, err) - assert.Equal(t, 11, n) - assert.Equal(t, "hello world", string(buf)) -} - -func TestNewSyncReaderReadEOF(t *testing.T) { - inner := strings.NewReader("") - sr := NewSyncReader[io.Reader](inner) - - buf := make([]byte, 8) - _, err := sr.Read(buf) - assert.Equal(t, io.EOF, err) -} - -func TestNewSyncReaderCloseWithCloser(t *testing.T) { - rc := &trackingReadCloser{Reader: strings.NewReader("data")} - sr := NewSyncReader(rc) - - err := sr.Close() - assert.NoError(t, err) - assert.True(t, rc.closed) -} - -func TestNewSyncReaderCloseNonCloser(t *testing.T) { - // strings.Reader does not implement io.Closer; Close should be a no-op - inner := strings.NewReader("data") - sr := NewSyncReader[io.Reader](inner) - - err := sr.Close() - assert.NoError(t, err) -} - -func TestNewSyncReaderCloseError(t *testing.T) { - closeErr := errors.New("close failed") - rc := &trackingReadCloser{Reader: strings.NewReader("data"), err: closeErr} - sr := NewSyncReader(rc) - - err := sr.Close() - assert.Equal(t, closeErr, err) -} - -func TestNewSyncReaderLocked(t *testing.T) { - inner := strings.NewReader("locked") - sr := NewSyncReader(inner) - - var seen string - err := sr.Locked(func(r *strings.Reader) error { - buf := make([]byte, 6) - n, err := r.Read(buf) - seen = string(buf[:n]) - return err - }) - assert.NoError(t, err) - assert.Equal(t, "locked", seen) -} - -func TestNewSyncReaderLockedPropagatesError(t *testing.T) { - inner := strings.NewReader("data") - sr := NewSyncReader(inner) - - sentinel := errors.New("locked error") - err := sr.Locked(func(_ *strings.Reader) error { - return sentinel - }) - assert.Equal(t, sentinel, err) -} - -func TestNewSyncReaderMutex(t *testing.T) { - inner := strings.NewReader("shared mutex") - mutex := new(sync.Mutex) - sr := NewSyncReaderMutex[io.Reader](inner, mutex, mutex) - - buf := make([]byte, 12) - n, err := sr.Read(buf) - assert.NoError(t, err) - assert.Equal(t, 12, n) - assert.Equal(t, "shared mutex", string(buf)) -} - -func TestNewSyncReaderMutexSeparateCloseMutex(t *testing.T) { - rc := &trackingReadCloser{Reader: strings.NewReader("separate")} - readMutex := new(sync.Mutex) - closeMutex := new(sync.Mutex) - sr := NewSyncReaderMutex(rc, readMutex, closeMutex) - - err := sr.Close() - assert.NoError(t, err) - assert.True(t, rc.closed) -} - -func TestNewSyncReaderConcurrentReads(t *testing.T) { - // Writes a large enough buffer so concurrent goroutines don't exhaust it - // immediately; we just want to confirm no races / panics occur. - data := strings.Repeat("x", 1024) - inner := strings.NewReader(data) - sr := NewSyncReader[io.Reader](inner) - - var wg sync.WaitGroup - for range 10 { - wg.Go(func() { - buf := make([]byte, 32) - _, _ = sr.Read(buf) - }) - } - wg.Wait() -} - -func TestNewSyncReaderConcurrentReadAndClose(t *testing.T) { - rc := &trackingReadCloser{Reader: strings.NewReader(strings.Repeat("y", 512))} - sr := NewSyncReader(rc) - - var wg sync.WaitGroup - for range 5 { - wg.Go(func() { - buf := make([]byte, 16) - _, _ = sr.Read(buf) - }) - } - wg.Go(func() { - _ = sr.Close() - }) - wg.Wait() -} - -func TestNewSyncReaderReadAll(t *testing.T) { - inner := strings.NewReader("read all content") - sr := NewSyncReader[io.Reader](inner) - - content, err := io.ReadAll(sr) - require.NoError(t, err) - assert.Equal(t, "read all content", string(content)) -} From 9a31617faf73e4399ed6c4f982fb54f49c4f8df1 Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 21:42:17 +0100 Subject: [PATCH 07/15] refactor: remove TestAsyncFileWriterFilePerm test case to streamline file writer tests --- util/filewriter_test.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/util/filewriter_test.go b/util/filewriter_test.go index 2d968642..47cf691e 100644 --- a/util/filewriter_test.go +++ b/util/filewriter_test.go @@ -133,21 +133,6 @@ func TestAsyncFileWriterAppendMode(t *testing.T) { assert.Equal(t, "firstsecond", string(content)) } -func TestAsyncFileWriterFilePerm(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename, WithAsyncFilePerm(0600)) - require.NoError(t, err) - _, err = w.Write([]byte("data")) - require.NoError(t, err) - require.NoError(t, w.Close()) - - info, err := os.Stat(filename) - require.NoError(t, err) - assert.Equal(t, os.FileMode(0600), info.Mode().Perm()) -} - func TestAsyncFileWriterCustomAppendFunc(t *testing.T) { dir := t.TempDir() filename := filepath.Join(dir, "test.log") From 43e324de30978b5f6e770a6f3cd26d9107a83e15 Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 21:52:56 +0100 Subject: [PATCH 08/15] refactor: update codecov configuration and remove unused shell utility functions --- codecov.yml | 1 + shell/util.go | 45 ----------------------------------------- term/nil_reader_test.go | 16 +++++++++++++++ 3 files changed, 17 insertions(+), 45 deletions(-) delete mode 100644 shell/util.go create mode 100644 term/nil_reader_test.go diff --git a/codecov.yml b/codecov.yml index 98e325fa..95e26d52 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,6 +4,7 @@ ignore: - priority/check/ - util/test_executable/ - term/remote.go + - term/test/ - deprecation.go - logger.go - systemd.go diff --git a/shell/util.go b/shell/util.go deleted file mode 100644 index 700b6a05..00000000 --- a/shell/util.go +++ /dev/null @@ -1,45 +0,0 @@ -package shell - -import ( - "bufio" - "bytes" - "io" - - "github.com/creativeprojects/resticprofile/platform" -) - -var ( - bogusPrefix = []byte("\r\x1b[2K") -) - -func LineOutputFilter(output io.Writer, included func(line []byte) bool) io.WriteCloser { - eol := []byte(platform.LineSeparator) - - reader, writer := io.Pipe() - - go func() { - var err error - defer func() { - _ = reader.CloseWithError(err) - }() - - scanner := bufio.NewScanner(reader) - for err == nil && scanner.Scan() { - line := bytes.TrimPrefix(scanner.Bytes(), bogusPrefix) - if !included(line) { - continue - } - if err == nil { - _, err = output.Write(line) - } - if err == nil { - _, err = output.Write(eol) - } - } - if err == nil { - err = scanner.Err() - } - }() - - return writer -} diff --git a/term/nil_reader_test.go b/term/nil_reader_test.go new file mode 100644 index 00000000..0d7d1895 --- /dev/null +++ b/term/nil_reader_test.go @@ -0,0 +1,16 @@ +package term + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadingFromNilReader(t *testing.T) { + r := new(nilReader) + buffer, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, 0, len(buffer)) +} From f9aaed5e492e288821fde60903c2738a7807c8f4 Mon Sep 17 00:00:00 2001 From: Fred Date: Sun, 5 Apr 2026 21:57:47 +0100 Subject: [PATCH 09/15] test: add TestContextWithTerminal to validate terminal assignment --- context_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/context_test.go b/context_test.go index cecd55a7..9e6e462f 100644 --- a/context_test.go +++ b/context_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/creativeprojects/resticprofile/config" + "github.com/creativeprojects/resticprofile/term" "github.com/stretchr/testify/assert" ) @@ -56,6 +57,13 @@ func TestContextWithGroup(t *testing.T) { assert.NotEmpty(t, ctx.request.command) } +func TestContextWithTerminal(t *testing.T) { + terminal := term.NewTerminal() + ctx := &Context{} + ctx = ctx.WithTerminal(terminal) + assert.Same(t, terminal, ctx.terminal) +} + func TestContextWithProfile(t *testing.T) { ctx := &Context{ request: Request{ From 28366c694cbb91cb096670a9f319466386f46d02 Mon Sep 17 00:00:00 2001 From: Fred Date: Mon, 6 Apr 2026 16:04:46 +0100 Subject: [PATCH 10/15] refactor: enhance terminal handling with new options and improve output duplication --- Makefile | 4 +- commands.go | 6 ++- logger.go | 20 ++++---- main.go | 101 ++++++++++++++++++++++++---------------- term/terminal.go | 50 +++++++++++++++----- term/terminal_option.go | 52 +++++++++++++++++---- 6 files changed, 156 insertions(+), 77 deletions(-) diff --git a/Makefile b/Makefile index f563cee1..4c771dcc 100644 --- a/Makefile +++ b/Makefile @@ -178,9 +178,9 @@ build-windows: prepare_build ## Build the binary for Windows build-all: build-mac build-linux build-pi build-windows ## Build the binary for all platforms -test: prepare_test ## Run unit tests +test: $(GOBIN)/gotestsum prepare_test ## Run unit tests @echo "[*] $@" - $(GOTEST) $(TESTS) + $(GOBIN)/gotestsum $(TESTS) test-ci: $(GOBIN)/gotestsum prepare_test ## Run unit tests with coverage (for CI) @echo "[*] $@" diff --git a/commands.go b/commands.go index 5e638b20..d426a05b 100644 --- a/commands.go +++ b/commands.go @@ -22,6 +22,7 @@ import ( "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/platform" "github.com/creativeprojects/resticprofile/remote" + "github.com/creativeprojects/resticprofile/util/ansi" "github.com/creativeprojects/resticprofile/win" "github.com/distatus/battery" ) @@ -338,9 +339,10 @@ func randomKey(ctx commandContext) error { func testElevationCommand(ctx commandContext) error { if ctx.flags.isChild { client := remote.NewClient(ctx.flags.parentPort) - ctx.terminal.Print("first line", "\n") - ctx.terminal.Println("second", "one") + ctx.terminal.Print("simple line", "\n") ctx.terminal.Printf("value = %d\n", 11) + ctx.terminal.Println(ansi.Bold("in bold")) + clog.Info("log content") err := client.Done() if err != nil { return err diff --git a/logger.go b/logger.go index 816609b4..ec100138 100644 --- a/logger.go +++ b/logger.go @@ -49,7 +49,7 @@ func setupRemoteLogger(flags commandLineFlags, client *remote.Client) { clog.SetDefaultLogger(logger) } -func setupTargetLogger(flags commandLineFlags, terminal *term.Terminal, logTarget, commandOutput string) (io.Closer, *term.Terminal, error) { +func setupTargetLogger(flags commandLineFlags, terminal *term.Terminal, logTarget, commandOutput string) (io.Closer, []term.TerminalOption, error) { var ( handler LogCloser file io.Writer @@ -70,24 +70,24 @@ func setupTargetLogger(flags commandLineFlags, terminal *term.Terminal, logTarge // default logger added with level filtering clog.SetDefaultLogger(logger) - var newTerminal *term.Terminal - + var terminalOptions []term.TerminalOption // also redirect all terminal output if file != nil { if all, toLog := parseCommandOutput(terminal, commandOutput); all { - newTerminal = term.NewTerminal( - term.WithStdout(io.MultiWriter(file, terminal.Stdout())), - term.WithStderr(io.MultiWriter(file, terminal.Stderr())), - ) + clog.Debugf("sending a copy of the console logs to %q", logTarget) + terminalOptions = []term.TerminalOption{ + term.WithStdoutCopy(file), + term.WithStderrCopy(file), + } } else if toLog { - newTerminal = term.NewTerminal( + terminalOptions = []term.TerminalOption{ term.WithStdout(file), term.WithStderr(file), - ) + } } } // and return the handler (so we can close it at the end) - return handler, newTerminal, nil + return handler, terminalOptions, nil } func parseCommandOutput(terminal *term.Terminal, commandOutput string) (all, log bool) { diff --git a/main.go b/main.go index 4229dc0b..e8f72c7a 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,9 @@ func main() { // run shutdown hooks just before returning an exit code defer shutdown.RunHooks() + // prepare a default terminal before loading the terminal configuration terminal := term.Set(term.NewTerminal()) + terminalOptions := make([]term.TerminalOption, 0, 5) args := os.Args[1:] _, flags, flagErr := loadFlags(args) @@ -68,9 +70,10 @@ func main() { return } - // Configure terminal color + // Now that we have loaded the flags, configure terminal color if flags.noAnsi || flags.theme == "none" { - terminal = term.Set(term.NewTerminal(term.WithColors(false))) // disable colors + terminalOptions = append(terminalOptions, term.WithColors(false)) // disable colors + terminal = term.Set(term.NewTerminal(terminalOptions...)) } if flags.wait { @@ -116,45 +119,51 @@ func main() { // also redirect the terminal through the client remoteTerm := term.NewRemoteTerm(client) - terminal = term.Set(term.NewTerminal(term.WithStdout(remoteTerm), term.WithStderr(remoteTerm), term.WithColors(!flags.noAnsi))) - } else { - logTarget, commandOutput := "", "" - if ctx != nil { - logTarget = ctx.logTarget - commandOutput = ctx.commandOutput - if ctx.request.command == constants.CommandCat || - ctx.request.command == constants.CommandDump { - clog.Debugf("redirecting console to stderr for command %q", ctx.request.command) - flags.stderr = true - } - terminal = term.Set(term.NewTerminal(term.WithStdout(os.Stderr), term.WithColors(!flags.noAnsi))) + terminalOptions = append(terminalOptions, term.WithStdout(remoteTerm), term.WithStderr(remoteTerm)) + + return + } + logTarget, commandOutput := "", "" + if ctx != nil { + logTarget = ctx.logTarget + commandOutput = ctx.commandOutput + if ctx.request.command == constants.CommandCat || + ctx.request.command == constants.CommandDump { + clog.Debugf("redirecting console to stderr for command %q", ctx.request.command) + flags.stderr = true } - if logTarget != "" && logTarget != "-" { - if closer, newTerminal, err := setupTargetLogger(flags, terminal, logTarget, commandOutput); err == nil { - logCloser = func() { _ = closer.Close() } - if newTerminal != nil { - terminal = term.Set(newTerminal) - } - } else { - // fallback to a console logger - setupConsoleLogger(flags) - clog.Errorf("cannot open log target: %s", err) - } - } else { - // use the console logger - setupConsoleLogger(flags) + if flags.stderr { + terminalOptions = append(terminalOptions, term.WithStdout(os.Stderr)) } } + if logTarget != "" && logTarget != "-" { + if closer, options, err := setupTargetLogger(flags, terminal, logTarget, commandOutput); err == nil { + logCloser = func() { _ = closer.Close() } + terminalOptions = append(terminalOptions, options...) + return + } + // fallback to a console logger + setupConsoleLogger(flags) + clog.Errorf("cannot open log target: %s", err) + + return + } + // use the console logger + setupConsoleLogger(flags) + return } + // refresh terminal with new config + terminal = term.Set(term.NewTerminal(terminalOptions...)) + // keep this one last if possible (so it will be first at the end) defer showPanicData() banner() if flags.remote != "" { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) closeFS, remoteParameters, err := setupRemoteConfiguration(ctx, flags.remote) cancel() if err != nil { @@ -184,15 +193,22 @@ func main() { command: flags.resticArgs[0], arguments: flags.resticArgs[1:], }, - terminal: terminal, + logTarget: flags.log, + commandOutput: flags.commandOutput, } + // try to load the config and setup logging for own command cfg, global, err := loadConfig(flags, true) if err == nil { ctx = ctx.WithConfig(cfg, global) } closeLogger := setupLogging(ctx) - defer closeLogger() + shutdown.AddHook(closeLogger) + + // refresh terminal with new logging config + terminal = term.Set(term.NewTerminal(terminalOptions...)) + ctx = ctx.WithTerminal(terminal) + err = ownCommands.Run(ctx) if err != nil { clog.Error(err) @@ -210,13 +226,15 @@ func main() { // Load the now mandatory configuration and setup logging (before returning an error) ctx, err := loadContext(flags) closeLogger := setupLogging(ctx) - defer closeLogger() + shutdown.AddHook(closeLogger) if err != nil { clog.Error(err) exitCode = constants.ExitGeneralError return } + // refresh terminal with new config + terminal = term.Set(term.NewTerminal(terminalOptions...)) ctx = ctx.WithTerminal(terminal) // check if we're running on battery @@ -236,14 +254,14 @@ func main() { } } // and stop at the end - defer func() { + shutdown.AddHook(func() { if caffeinate != nil { err = caffeinate.Stop() if err != nil { clog.Error(err) } } - }() + }) // Check memory pressure if ctx.global.MinMemory > 0 { @@ -429,14 +447,15 @@ func free() uint64 { func showPanicData() { if r := recover(); r != nil { + terminal := term.Get() message := ` -=============================================================== -uh-oh! resticprofile crashed miserably :-( -Can you please open an issue on github including these details: -=============================================================== +================================================================= + uh-oh! resticprofile crashed miserably :-( + Can you please open an issue on github including these details: +================================================================= ` - fmt.Fprint(os.Stderr, message) - w := tabwriter.NewWriter(os.Stderr, 0, 0, 3, ' ', 0) + fmt.Fprint(terminal.Stderr(), message) + w := tabwriter.NewWriter(terminal.Stderr(), 0, 0, 3, ' ', 0) _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "os", runtime.GOOS) _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "arch", runtime.GOARCH) _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "version", version) @@ -446,6 +465,6 @@ Can you please open an issue on github including these details: _, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "error", r) _, _ = fmt.Fprintf(w, "\t%s:\n%s\n", "stack", getStack(3)) // skip calls to getStack - showPanicData - panic w.Flush() - fmt.Fprint(os.Stderr, "===============================================================\n") + fmt.Fprintln(terminal.Stderr(), "=================================================================") } } diff --git a/term/terminal.go b/term/terminal.go index 8288635b..860c59a6 100644 --- a/term/terminal.go +++ b/term/terminal.go @@ -13,6 +13,13 @@ import ( "golang.org/x/term" ) +// Journey of a message sent to the terminal: +// 1. entrypoint is inputStdout/inputStderr +// 2. data is sent to colorableStdout/colorableStderr +// 3. a copy of the data is also sent to copyStdout/Stderr if enabled (final writer for the copy) +// 4. the data is finally written to the stdin/stdout writers (coming from colorable writers) + +// Terminal gives access to the standard terminal input/output type Terminal struct { stdin io.Reader stdout io.Writer // final writer @@ -20,9 +27,13 @@ type Terminal struct { enableColors maybe.Bool colorableStdout io.Writer // colorable writer colorableStderr io.Writer // colorable writer + copyStdout io.Writer // stdout duplicate + copyStderr io.Writer // stderr duplicate + inputStdout io.Writer // entry for stdout + inputStderr io.Writer // entry for stderr } -func NewTerminal(options ...terminalOption) *Terminal { +func NewTerminal(options ...TerminalOption) *Terminal { t := &Terminal{ stdin: os.Stdin, stdout: os.Stdout, @@ -35,6 +46,20 @@ func NewTerminal(options ...terminalOption) *Terminal { t.colorableStdout = t.getColorableWriter(t.stdout) t.colorableStderr = t.getColorableWriter(t.stderr) + if t.copyStdout == nil { + // no copy, we send to colorable directly + t.inputStdout = t.colorableStdout + } else { + // send to both the output and the copy + t.inputStdout = io.MultiWriter(t.colorableStdout, t.getColorableWriter(t.copyStdout)) + } + if t.copyStderr == nil { + // no copy, we send to colorable directly + t.inputStderr = t.colorableStderr + } else { + // send to both the output and the copy + t.inputStderr = io.MultiWriter(t.colorableStderr, t.getColorableWriter(t.copyStderr)) + } return t } @@ -108,16 +133,17 @@ func (t *Terminal) getColorableWriter(w io.Writer) io.Writer { if file, ok := w.(*os.File); ok && t.enableColors.IsTrueOrUndefined() && (isTerminal(file) || t.enableColors.IsTrue()) { return colorable.NewColorable(file) } - if t.enableColors.IsTrue() { - // output is not a file, but the user doesn't want to strip the colors - return w - } return colorable.NewNonColorable(w) } // FlushAllOutput flushes all buffered output (if supported by the underlying Writer). func (t *Terminal) FlushAllOutput() { - for _, writer := range []io.Writer{t.colorableStdout, t.colorableStderr, t.stdout, t.stderr} { + for _, writer := range []io.Writer{ + t.inputStdout, t.inputStderr, + t.copyStdout, t.copyStderr, + t.colorableStdout, t.colorableStderr, + t.stdout, t.stderr, + } { _, _ = util.FlushWriter(writer) } } @@ -126,20 +152,20 @@ func (t *Terminal) FlushAllOutput() { // Spaces are added between operands when neither is a string. // It returns the number of bytes written and any write error encountered. func (t *Terminal) Print(a ...any) (n int, err error) { - return fmt.Fprint(t.colorableStdout, a...) + return fmt.Fprint(t.inputStdout, a...) } // Println formats using the default formats for its operands and writes to standard output. // Spaces are always added between operands and a newline is appended. // It returns the number of bytes written and any write error encountered. func (t *Terminal) Println(a ...any) (n int, err error) { - return fmt.Fprintln(t.colorableStdout, a...) + return fmt.Fprintln(t.inputStdout, a...) } // Printf formats according to a format specifier and writes to standard output. // It returns the number of bytes written and any write error encountered. func (t *Terminal) Printf(format string, a ...any) (n int, err error) { - return fmt.Fprintf(t.colorableStdout, format, a...) + return fmt.Fprintf(t.inputStdout, format, a...) } func (t *Terminal) Scanln(a ...any) (n int, err error) { @@ -148,15 +174,15 @@ func (t *Terminal) Scanln(a ...any) (n int, err error) { // Write implements the io.Writer interface, writing to the terminal's stdout. func (t *Terminal) Write(p []byte) (n int, err error) { - return t.colorableStdout.Write(p) + return t.inputStdout.Write(p) } func (t *Terminal) Stdout() io.Writer { - return t.colorableStdout + return t.inputStdout } func (t *Terminal) Stderr() io.Writer { - return t.colorableStderr + return t.inputStderr } func isTerminalWriter(w io.Writer) bool { diff --git a/term/terminal_option.go b/term/terminal_option.go index 81dd676c..a503c16a 100644 --- a/term/terminal_option.go +++ b/term/terminal_option.go @@ -8,20 +8,38 @@ import ( "github.com/creativeprojects/resticprofile/util/maybe" ) -type terminalOption func(t *Terminal) +type TerminalOption func(t *Terminal) -func WithStdin(stdin io.Reader) terminalOption { +func WithNoStdin(stdin io.Reader) TerminalOption { + return func(t *Terminal) { + t.stdin = nilReader{} + } +} + +func WithNoStdout() TerminalOption { + return func(t *Terminal) { + t.stdout = io.Discard + } +} + +func WithNoStderr(stderr io.Writer) TerminalOption { + return func(t *Terminal) { + t.stderr = io.Discard + } +} + +func WithStdin(stdin io.Reader) TerminalOption { if stdin == nil { - stdin = nilReader{} + return func(t *Terminal) {} } return func(t *Terminal) { t.stdin = stdin } } -func WithStdout(stdout io.Writer) terminalOption { +func WithStdout(stdout io.Writer) TerminalOption { if stdout == nil { - stdout = io.Discard + return func(t *Terminal) {} } if stdout != os.Stdout && stdout != os.Stderr { stdout = util.NewSyncWriter(stdout) @@ -31,9 +49,9 @@ func WithStdout(stdout io.Writer) terminalOption { } } -func WithStderr(stderr io.Writer) terminalOption { +func WithStderr(stderr io.Writer) TerminalOption { if stderr == nil { - stderr = io.Discard + return func(t *Terminal) {} } if stderr != os.Stdout && stderr != os.Stderr { stderr = util.NewSyncWriter(stderr) @@ -43,20 +61,34 @@ func WithStderr(stderr io.Writer) terminalOption { } } -func WithColors(enable bool) terminalOption { +func WithColors(enable bool) TerminalOption { return func(t *Terminal) { t.enableColors = maybe.SetBool(enable) } } -func WithStdoutRecorder(recorder *Recorder) terminalOption { +func WithStdoutRecorder(recorder *Recorder) TerminalOption { return func(t *Terminal) { t.stdout = recorder.inputWriter } } -func WithStderrRecorder(recorder *Recorder) terminalOption { +func WithStderrRecorder(recorder *Recorder) TerminalOption { return func(t *Terminal) { t.stderr = recorder.inputWriter } } + +// WithStdoutCopy creates a copy of everything sent to Stdout to `copy` writer. +// Colorisation is independent: the Stdout writer can display colors while the copy to a non terminal won't display colors. +func WithStdoutCopy(w io.Writer) TerminalOption { + return func(t *Terminal) { + t.copyStdout = w + } +} + +func WithStderrCopy(w io.Writer) TerminalOption { + return func(t *Terminal) { + t.copyStderr = w + } +} From 7fc967e7614dbd300da9c10c15b5247ed2770c08 Mon Sep 17 00:00:00 2001 From: Fred Date: Mon, 6 Apr 2026 22:16:26 +0100 Subject: [PATCH 11/15] refactor: replace AsyncFileWriter with new write package and improve file handling --- logger.go | 10 +- term/terminal_option.go | 4 +- term/terminal_option_test.go | 45 ++++++ util/filewriter.go | 236 -------------------------------- util/filewriter_test.go | 256 ----------------------------------- util/write/append.go | 29 ++++ util/write/append_test.go | 38 ++++++ util/write/async.go | 119 ++++++++++++++++ util/write/async_option.go | 12 ++ util/write/async_test.go | 94 +++++++++++++ util/write/file.go | 77 +++++++++++ util/write/file_option.go | 25 ++++ util/write/file_test.go | 56 ++++++++ 13 files changed, 501 insertions(+), 500 deletions(-) create mode 100644 term/terminal_option_test.go delete mode 100644 util/filewriter.go delete mode 100644 util/filewriter_test.go create mode 100644 util/write/append.go create mode 100644 util/write/append_test.go create mode 100644 util/write/async.go create mode 100644 util/write/async_option.go create mode 100644 util/write/async_test.go create mode 100644 util/write/file.go create mode 100644 util/write/file_option.go create mode 100644 util/write/file_test.go diff --git a/logger.go b/logger.go index ec100138..b612b72e 100644 --- a/logger.go +++ b/logger.go @@ -17,6 +17,7 @@ import ( "github.com/creativeprojects/resticprofile/term" "github.com/creativeprojects/resticprofile/util" "github.com/creativeprojects/resticprofile/util/collect" + "github.com/creativeprojects/resticprofile/util/write" "github.com/fatih/color" ) @@ -120,7 +121,7 @@ func getFileHandler(logfile string) (*clog.StandardLogHandler, io.Writer, error) } // create a platform aware log file appender - var appender util.AsyncFileWriterAppendFunc + var appender write.WriterAppendFunc if platform.IsWindows() { appender = func(dst []byte, c byte) []byte { switch c { @@ -133,14 +134,11 @@ func getFileHandler(logfile string) (*clog.StandardLogHandler, io.Writer, error) } } - writer, err := util.NewAsyncFileWriter( - logfile, - util.WithAsyncFileAppendFunc(appender), - util.WithAsyncFilePerm(0644), - ) + file, err := write.NewFile(logfile, write.WithFilePerm(0644)) if err != nil { return nil, nil, err } + writer := write.NewAsync(write.NewAppend(file, appender)) return clog.NewStandardLogHandler(writer, "", log.LstdFlags), writer, nil } diff --git a/term/terminal_option.go b/term/terminal_option.go index a503c16a..81da50e3 100644 --- a/term/terminal_option.go +++ b/term/terminal_option.go @@ -10,7 +10,7 @@ import ( type TerminalOption func(t *Terminal) -func WithNoStdin(stdin io.Reader) TerminalOption { +func WithNoStdin() TerminalOption { return func(t *Terminal) { t.stdin = nilReader{} } @@ -22,7 +22,7 @@ func WithNoStdout() TerminalOption { } } -func WithNoStderr(stderr io.Writer) TerminalOption { +func WithNoStderr() TerminalOption { return func(t *Terminal) { t.stderr = io.Discard } diff --git a/term/terminal_option_test.go b/term/terminal_option_test.go new file mode 100644 index 00000000..abfeeaf0 --- /dev/null +++ b/term/terminal_option_test.go @@ -0,0 +1,45 @@ +package term + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTerminalNoInput(t *testing.T) { + terminal := NewTerminal(WithNoStdin()) + pwd, err := terminal.ReadPassword() + require.Error(t, err) + assert.Empty(t, pwd) + + var line string + read, err := terminal.Scanln(&line) + require.Error(t, err) + assert.Empty(t, line) + assert.Empty(t, read) +} + +func TestTerminalNoStdout(t *testing.T) { + terminal := NewTerminal(WithNoStdout()) + written, err := terminal.Print("something") + require.NoError(t, err) + assert.Equal(t, 9, written) +} + +func TestTerminalNoStderr(t *testing.T) { + terminal := NewTerminal(WithNoStderr()) + written, err := fmt.Fprint(terminal.Stderr(), "something") + require.NoError(t, err) + assert.Equal(t, 9, written) +} + +func TestTestTerminalStdoutCopy(t *testing.T) { + buffer := new(bytes.Buffer) + terminal := NewTerminal(WithStdoutCopy(buffer)) + _, err := terminal.Printf("%s test", "copy") + require.NoError(t, err) + assert.Equal(t, "copy test", buffer.String()) +} diff --git a/util/filewriter.go b/util/filewriter.go deleted file mode 100644 index a462022c..00000000 --- a/util/filewriter.go +++ /dev/null @@ -1,236 +0,0 @@ -package util - -import ( - "errors" - "io" - "os" - "sync" - "sync/atomic" - "time" - - "github.com/creativeprojects/resticprofile/platform" -) - -// AsyncFileWriterAppendFunc is called for every input byte when appending it to the output buffer (is similar to buf = append(buf, byte)) -type AsyncFileWriterAppendFunc func(dst []byte, c byte) []byte - -type asyncFileWriterOption func(writer *asyncFileWriter) - -// WithAsyncWriteInterval sets the interval at which writes happen at least when data is pending -func WithAsyncWriteInterval(duration time.Duration) asyncFileWriterOption { - return func(writer *asyncFileWriter) { writer.interval = duration } -} - -// WithAsyncFileKeepOpen toggles whether the file is kept open between writes. Defaults to true for all OS except Windows. -func WithAsyncFileKeepOpen(keepOpen bool) asyncFileWriterOption { - return func(writer *asyncFileWriter) { writer.keepOpen = keepOpen } -} - -// WithAsyncFileAppendFunc sets AsyncFileWriterAppendFunc. Default is to not use a custom appender. -func WithAsyncFileAppendFunc(appender AsyncFileWriterAppendFunc) asyncFileWriterOption { - return func(writer *asyncFileWriter) { writer.appender = appender } -} - -// WithAsyncFilePerm sets file perms to apply when creating the file -func WithAsyncFilePerm(perm os.FileMode) asyncFileWriterOption { - return func(writer *asyncFileWriter) { writer.perm = perm } -} - -// WithAsyncFileFlag sets file open flags -func WithAsyncFileFlag(flag int) asyncFileWriterOption { - return func(writer *asyncFileWriter) { writer.flag = flag } -} - -// WithAsyncFileTruncate enables that existing files are truncated -func WithAsyncFileTruncate() asyncFileWriterOption { - return func(writer *asyncFileWriter) { writer.flag |= os.O_TRUNC } -} - -const ( - asyncWriterDataChanSize = 64 - asyncWriterBlockSize = 4 * 1024 - asyncWriterMaxBlockSize = asyncWriterDataChanSize * asyncWriterBlockSize -) - -var ( - asyncWriterBufferPool = sync.Pool{ - New: func() any { return make([]byte, asyncWriterBlockSize) }, - } -) - -func asyncWriterReturnToPool(data []byte) { - if cap(data) == asyncWriterBlockSize { - asyncWriterBufferPool.Put(data[:0]) //nolint:staticcheck - } -} - -// NewAsyncFileWriter creates a file writer that accumulates Write requests and writes them at a fixed rate (every 250 ms by default) -func NewAsyncFileWriter(filename string, options ...asyncFileWriterOption) (io.WriteCloser, error) { - w := &asyncFileWriter{ - flush: make(chan chan error), - done: make(chan chan error), - data: make(chan []byte, asyncWriterDataChanSize), - perm: 0644, - flag: os.O_WRONLY | os.O_APPEND | os.O_CREATE, - interval: 250 * time.Millisecond, - keepOpen: !platform.IsWindows(), - } - for _, option := range options { - option(w) - } - - var ( - buffer []byte - lastError error - file *os.File - ) - - closeFile := func() { - if file != nil { - lastError = errors.Join(lastError, file.Close()) - file = nil - } - } - - flush := func(alsoEmpty, whenTooBig bool) { - if len(buffer) == 0 && !alsoEmpty { - return - } - if len(buffer) < asyncWriterMaxBlockSize && whenTooBig { - return - } - if file == nil { - file, lastError = os.OpenFile(filename, w.flag, w.perm) - } - if file != nil { - var written int - written, lastError = file.Write(buffer) - if remaining := len(buffer) - written; remaining > 0 { - copy(buffer, buffer[written:]) - buffer = buffer[:remaining] - } else { - buffer = buffer[:0] - } - if w.keepOpen { - _ = file.Sync() - } else { - closeFile() - } - } - } - - // test if we can create the file - buffer = make([]byte, 0, asyncWriterBlockSize) - flush(true, false) - - // data appending - addToBuffer := func(data []byte) { - buffer = append(buffer, data...) // fast path - asyncWriterReturnToPool(data) - flush(false, true) - } - if w.appender != nil { - addToBuffer = func(data []byte) { - for _, c := range data { - buffer = w.appender(buffer, c) - } - asyncWriterReturnToPool(data) - flush(false, true) - } - } - - addPendingData := func(maxCount int) { - for ; maxCount > 0; maxCount-- { - select { - case data, ok := <-w.data: - if ok { - addToBuffer(data) - } else { - return // closed - } - default: - return // no-more-data - } - } - } - - // data transport - go func() { - ticker := time.NewTicker(w.interval) - defer ticker.Stop() - defer closeFile() - - for { - select { - case data := <-w.data: - addToBuffer(data) - case <-ticker.C: - flush(false, false) - case req := <-w.flush: - addPendingData(1024) - flush(false, false) - req <- lastError - case req := <-w.done: - defer func(response chan error) { - response <- lastError - }(req) - close(w.done) - close(w.flush) - close(w.data) - addPendingData(1024) - flush(false, false) - closeFile() - return - } - } - }() - - return w, lastError -} - -var errAlreadyClosed = errors.New("file writer already closed") - -type asyncFileWriter struct { - done, flush chan chan error - data chan []byte - appender AsyncFileWriterAppendFunc - flag int - perm os.FileMode - keepOpen bool - interval time.Duration - closed atomic.Bool -} - -func (w *asyncFileWriter) Close() error { - if !w.closed.CompareAndSwap(false, true) { - return errAlreadyClosed - } - req := make(chan error) - w.done <- req - return <-req -} - -func (w *asyncFileWriter) Flush() error { - if w.closed.Load() { - return errAlreadyClosed - } - req := make(chan error) - w.flush <- req - return <-req -} - -func (w *asyncFileWriter) Write(data []byte) (n int, err error) { - if w.closed.Load() { - return 0, errAlreadyClosed - } - var buffer []byte - if len(data) <= asyncWriterBlockSize { - buffer = asyncWriterBufferPool.Get().([]byte)[:len(data)] - } else { - buffer = make([]byte, len(data)) - } - - n = copy(buffer, data) - w.data <- buffer - return -} diff --git a/util/filewriter_test.go b/util/filewriter_test.go deleted file mode 100644 index 47cf691e..00000000 --- a/util/filewriter_test.go +++ /dev/null @@ -1,256 +0,0 @@ -package util - -import ( - "bytes" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAsyncFileWriterDoubleClose(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - - err = w.Close() - assert.NoError(t, err) - err = w.Close() - assert.Error(t, err) -} - -func TestAsyncFileWriterBasicWrite(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - - n, err := w.Write([]byte("hello world")) - assert.NoError(t, err) - assert.Equal(t, 11, n) - - err = w.Close() - assert.NoError(t, err) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "hello world", string(content)) -} - -func TestAsyncFileWriterMultipleWrites(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - - for _, chunk := range []string{"foo", "bar", "baz"} { - _, err = w.Write([]byte(chunk)) - require.NoError(t, err) - } - - err = w.Close() - assert.NoError(t, err) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "foobarbaz", string(content)) -} - -func TestAsyncFileWriterFlush(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename, - WithAsyncWriteInterval(10*time.Second), // long interval so flush drives the write - ) - require.NoError(t, err) - defer w.Close() - - _, err = w.Write([]byte("flushed")) - require.NoError(t, err) - - fw, ok := w.(*asyncFileWriter) - require.True(t, ok) - err = fw.Flush() - assert.NoError(t, err) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "flushed", string(content)) -} - -func TestAsyncFileWriterTruncate(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - // First write - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - _, err = w.Write([]byte("original")) - require.NoError(t, err) - require.NoError(t, w.Close()) - - // Second write with truncate - w, err = NewAsyncFileWriter(filename, WithAsyncFileTruncate()) - require.NoError(t, err) - _, err = w.Write([]byte("new")) - require.NoError(t, err) - require.NoError(t, w.Close()) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "new", string(content)) -} - -func TestAsyncFileWriterAppendMode(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - // First write - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - _, err = w.Write([]byte("first")) - require.NoError(t, err) - require.NoError(t, w.Close()) - - // Second write appends by default - w, err = NewAsyncFileWriter(filename) - require.NoError(t, err) - _, err = w.Write([]byte("second")) - require.NoError(t, err) - require.NoError(t, w.Close()) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "firstsecond", string(content)) -} - -func TestAsyncFileWriterCustomAppendFunc(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - // appender that uppercases every byte - upperAppender := func(dst []byte, c byte) []byte { - if c >= 'a' && c <= 'z' { - c -= 32 - } - return append(dst, c) - } - - w, err := NewAsyncFileWriter(filename, WithAsyncFileAppendFunc(upperAppender)) - require.NoError(t, err) - _, err = w.Write([]byte("hello")) - require.NoError(t, err) - require.NoError(t, w.Close()) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "HELLO", string(content)) -} - -func TestAsyncFileWriterKeepOpenFalse(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename, WithAsyncFileKeepOpen(false)) - require.NoError(t, err) - _, err = w.Write([]byte("keep-closed")) - require.NoError(t, err) - require.NoError(t, w.Close()) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "keep-closed", string(content)) -} - -func TestAsyncFileWriterInvalidPath(t *testing.T) { - _, err := NewAsyncFileWriter("/nonexistent/path/test.log") - assert.Error(t, err) -} - -func TestAsyncFileWriterLargeWrite(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - // Write more than asyncWriterBlockSize (4KB) - large := strings.Repeat("x", asyncWriterBlockSize*3) - - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - n, err := w.Write([]byte(large)) - assert.NoError(t, err) - assert.Equal(t, len(large), n) - require.NoError(t, w.Close()) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, large, string(content)) -} - -func TestAsyncFileWriterWriteInterval(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename, WithAsyncWriteInterval(10*time.Millisecond)) - require.NoError(t, err) - defer w.Close() - - _, err = w.Write([]byte("interval")) - require.NoError(t, err) - - // Wait enough time for the ticker to fire - time.Sleep(50 * time.Millisecond) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, "interval", string(content)) -} - -func TestAsyncFileWriterWriteEmptyData(t *testing.T) { - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - - n, err := w.Write([]byte{}) - assert.NoError(t, err) - assert.Equal(t, 0, n) - - require.NoError(t, w.Close()) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Empty(t, content) -} - -func TestAsyncFileWriterBufferReuse(t *testing.T) { - // Verify pooled-size writes don't corrupt data due to buffer reuse - dir := t.TempDir() - filename := filepath.Join(dir, "test.log") - - w, err := NewAsyncFileWriter(filename) - require.NoError(t, err) - - var expected bytes.Buffer - for i := range 10 { - chunk := []byte(strings.Repeat(string(rune('a'+i)), asyncWriterBlockSize)) - _, err = w.Write(chunk) - require.NoError(t, err) - expected.Write(chunk) - } - - require.NoError(t, w.Close()) - - content, err := os.ReadFile(filename) - require.NoError(t, err) - assert.Equal(t, expected.String(), string(content)) -} diff --git a/util/write/append.go b/util/write/append.go new file mode 100644 index 00000000..2f1cf25f --- /dev/null +++ b/util/write/append.go @@ -0,0 +1,29 @@ +package write + +import "io" + +// WriterAppendFunc is called for every input byte when appending it to the output buffer (is similar to buf = append(buf, byte)) +type WriterAppendFunc func(dst []byte, c byte) []byte + +type Append struct { + appender WriterAppendFunc + next io.Writer +} + +func NewAppend(w io.Writer, fn WriterAppendFunc) *Append { + return &Append{ + appender: fn, + next: w, + } +} + +func (a *Append) Write(data []byte) (int, error) { + if a.appender == nil { + return a.next.Write(data) + } + dst := make([]byte, 0, len(data)*2) + for _, c := range data { + dst = a.appender(dst, c) + } + return a.next.Write(dst) +} diff --git a/util/write/append_test.go b/util/write/append_test.go new file mode 100644 index 00000000..ca201280 --- /dev/null +++ b/util/write/append_test.go @@ -0,0 +1,38 @@ +package write + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAppendLinesNoAppender(t *testing.T) { + buffer := new(bytes.Buffer) + a := NewAppend(buffer, nil) + _, _ = a.Write([]byte("a\n")) + _, _ = a.Write([]byte("b\n")) + _, _ = a.Write([]byte("c\n")) + + assert.Equal(t, "a\nb\nc\n", buffer.String()) +} + +func TestAppendLinesWithAppender(t *testing.T) { + buffer := new(bytes.Buffer) + appender := func(dst []byte, c byte) []byte { + switch c { + case '\n': + return append(dst, '\r', '\n') // normalize to CRLF on Windows + case '\r': + return dst + } + return append(dst, c) + } + + a := NewAppend(buffer, appender) + _, _ = a.Write([]byte("a\n")) + _, _ = a.Write([]byte("b\n")) + _, _ = a.Write([]byte("c\n")) + + assert.Equal(t, "a\r\nb\r\nc\r\n", buffer.String()) +} diff --git a/util/write/async.go b/util/write/async.go new file mode 100644 index 00000000..bcd55995 --- /dev/null +++ b/util/write/async.go @@ -0,0 +1,119 @@ +package write + +import ( + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" +) + +const ( + asyncWriterDataChanSize = 64 + asyncWriterFlushChanSize = 16 + asyncWriterBlockSize = 4 * 1024 + asyncWriterMaxBlockSize = asyncWriterDataChanSize * asyncWriterBlockSize +) + +var ErrAlreadyClosed = errors.New("writer already closed") + +type Async struct { + handler io.Writer + interval time.Duration + data chan []byte + flusher chan chan struct{} + done chan struct{} + systemGroup sync.WaitGroup + closeOnce sync.Once + closed atomic.Bool +} + +// NewAsync creates a writer that accumulates Write requests and writes them at a fixed rate (every 250 ms by default) +func NewAsync(handler io.Writer, options ...AsyncOption) *Async { + w := &Async{ + handler: handler, + interval: 250 * time.Millisecond, + data: make(chan []byte, asyncWriterDataChanSize), + flusher: make(chan chan struct{}, asyncWriterFlushChanSize), + done: make(chan struct{}), + } + for _, option := range options { + option(w) + } + w.systemGroup.Go(func() { + w.intervalFlush() + }) + w.systemGroup.Go(func() { + w.recvFlush() + }) + return w +} + +func (w *Async) intervalFlush() { + ticker := time.NewTicker(w.interval) + for { + select { + case <-ticker.C: + go w.Flush() + case <-w.done: + ticker.Stop() + return + } + } +} + +func (w *Async) recvFlush() { + for done := range w.flusher { + w.flush() + close(done) + } +} + +// Close the writer. Any more call to Write will be ignored. +func (w *Async) Close() error { + var err error + w.closeOnce.Do(func() { + w.closed.Store(true) + close(w.done) + _ = w.Flush() + close(w.flusher) + w.systemGroup.Wait() + }) + return err +} + +func (w *Async) Flush() error { + done := make(chan struct{}) + w.flusher <- done + // wait until the flusher is done + <-done + return nil +} + +func (w *Async) flush() error { + for { + // keep reading from the channel until it's empty + select { + case data := <-w.data: + _, err := w.handler.Write(data) + if err != nil { + return err + } + default: + return nil + } + } +} + +// Write asynchronously to the handler +func (w *Async) Write(data []byte) (n int, err error) { + if w.closed.Load() { + return 0, fmt.Errorf("cannot write: %w", ErrAlreadyClosed) + } + + buffer := make([]byte, len(data)) + n = copy(buffer, data) + w.data <- buffer + return n, nil +} diff --git a/util/write/async_option.go b/util/write/async_option.go new file mode 100644 index 00000000..9bc89692 --- /dev/null +++ b/util/write/async_option.go @@ -0,0 +1,12 @@ +package write + +import ( + "time" +) + +type AsyncOption func(writer *Async) + +// WithWriteInterval sets the interval at which writes happen at least when data is pending +func WithWriteInterval(duration time.Duration) AsyncOption { + return func(writer *Async) { writer.interval = duration } +} diff --git a/util/write/async_test.go b/util/write/async_test.go new file mode 100644 index 00000000..24a06605 --- /dev/null +++ b/util/write/async_test.go @@ -0,0 +1,94 @@ +package write + +import ( + "bytes" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAsyncWriter(t *testing.T) { + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + n, err := w.Write([]byte("hello world")) + require.NoError(t, err) + assert.Equal(t, 11, n) + + err = w.Close() + require.NoError(t, err) + + assert.Equal(t, "hello world", buffer.String()) +} + +func TestAsyncWriteMoreThanChannelSize(t *testing.T) { + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + for range asyncWriterDataChanSize + 1 { + n, err := w.Write([]byte("aaa")) + require.NoError(t, err) + assert.Equal(t, 3, n) + } + + err := w.Close() + require.NoError(t, err) + + assert.Equal(t, strings.Repeat("aaa", asyncWriterDataChanSize+1), buffer.String()) +} + +func TestAsyncWriteInParallelAndClose(t *testing.T) { + repeat := 10 + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + wg := new(sync.WaitGroup) + for range repeat { + wg.Go(func() { + _, _ = w.Write([]byte("aa")) + }) + } + + require.NoError(t, w.Close()) + wg.Wait() +} + +func TestAsyncWriteInParallelAndWaitBeforeClosing(t *testing.T) { + repeat := 10 + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + wg := new(sync.WaitGroup) + for range repeat { + wg.Go(func() { + _, _ = w.Write([]byte("aa")) + }) + } + + wg.Wait() + require.NoError(t, w.Close()) + + assert.Equal(t, strings.Repeat("aa", repeat), buffer.String()) +} + +func TestAsyncWriteBigBuffersInParallelAndWaitBeforeClosing(t *testing.T) { + repeat := 100 + bufferSize := 1024 * 1024 + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + wg := new(sync.WaitGroup) + for range repeat { + wg.Go(func() { + buffer := make([]byte, bufferSize) + _, _ = w.Write(buffer) + }) + } + + wg.Wait() + require.NoError(t, w.Close()) + assert.Equal(t, repeat*bufferSize, buffer.Len()) +} diff --git a/util/write/file.go b/util/write/file.go new file mode 100644 index 00000000..80aff771 --- /dev/null +++ b/util/write/file.go @@ -0,0 +1,77 @@ +package write + +import ( + "errors" + "os" + + "github.com/creativeprojects/resticprofile/platform" +) + +type File struct { + filename string + perm os.FileMode + flag int + keepOpen bool + handle *os.File +} + +func NewFile(filename string, options ...FileOption) (f *File, err error) { + f = &File{ + filename: filename, + perm: 0644, + flag: os.O_WRONLY | os.O_APPEND | os.O_CREATE, + keepOpen: !platform.IsWindows(), + } + + for _, option := range options { + option(f) + } + + err = f.open() + if !f.keepOpen { + defer func() { + err = errors.Join(err, f.Close()) + }() + } + return +} + +func (f *File) open() error { + if f.handle != nil { + return nil + } + var err error + f.handle, err = os.OpenFile(f.filename, f.flag, f.perm) + return err +} + +func (f *File) Close() error { + if f.handle != nil { + err := f.handle.Close() + f.handle = nil + return err + } + return nil +} + +func (f *File) Flush() error { + if f.handle != nil { + return f.handle.Sync() + } + return nil +} + +func (f *File) Write(data []byte) (n int, err error) { + if !f.keepOpen { + err := f.open() + if err != nil { + return 0, err + } + defer func() { + err = errors.Join(err, f.Close()) + }() + } + n, err = f.handle.Write(data) + err = errors.Join(err, f.Flush()) + return +} diff --git a/util/write/file_option.go b/util/write/file_option.go new file mode 100644 index 00000000..f593dc9d --- /dev/null +++ b/util/write/file_option.go @@ -0,0 +1,25 @@ +package write + +import "os" + +type FileOption func(f *File) + +// WithFileKeepOpen toggles whether the file is kept open between writes. Defaults to true for all OS except Windows. +func WithFileKeepOpen(keepOpen bool) FileOption { + return func(f *File) { f.keepOpen = keepOpen } +} + +// WithFilePerm sets file perms to apply when creating the file +func WithFilePerm(perm os.FileMode) FileOption { + return func(f *File) { f.perm = perm } +} + +// WithFileFlag sets file open flags +func WithFileFlag(flag int) FileOption { + return func(f *File) { f.flag = flag } +} + +// WithFileTruncate enables that existing files are truncated +func WithFileTruncate() FileOption { + return func(f *File) { f.flag |= os.O_TRUNC } +} diff --git a/util/write/file_test.go b/util/write/file_test.go new file mode 100644 index 00000000..e7fd465a --- /dev/null +++ b/util/write/file_test.go @@ -0,0 +1,56 @@ +package write + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileDefaultOption(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename) + require.NoError(t, err) + + n, err := w.Write([]byte("hello world")) + assert.NoError(t, err) + assert.Equal(t, 11, n) + + err = w.Close() + assert.NoError(t, err) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) +} + +func TestFileCloseAfterWrite(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename, WithFileKeepOpen(false)) + require.NoError(t, err) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello", string(content)) + + n, err = w.Write([]byte(" world")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + + content, err = os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) + + err = w.Close() + assert.NoError(t, err) +} From c8abab188409794ab30b4e2d93aa27dbee728f96 Mon Sep 17 00:00:00 2001 From: Fred Date: Mon, 6 Apr 2026 22:24:42 +0100 Subject: [PATCH 12/15] refactor: change flusher channel type to error and update flush handling --- util/write/async.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/util/write/async.go b/util/write/async.go index bcd55995..54309667 100644 --- a/util/write/async.go +++ b/util/write/async.go @@ -22,7 +22,7 @@ type Async struct { handler io.Writer interval time.Duration data chan []byte - flusher chan chan struct{} + flusher chan chan error done chan struct{} systemGroup sync.WaitGroup closeOnce sync.Once @@ -35,7 +35,7 @@ func NewAsync(handler io.Writer, options ...AsyncOption) *Async { handler: handler, interval: 250 * time.Millisecond, data: make(chan []byte, asyncWriterDataChanSize), - flusher: make(chan chan struct{}, asyncWriterFlushChanSize), + flusher: make(chan chan error, asyncWriterFlushChanSize), done: make(chan struct{}), } for _, option := range options { @@ -65,8 +65,8 @@ func (w *Async) intervalFlush() { func (w *Async) recvFlush() { for done := range w.flusher { - w.flush() - close(done) + err := w.flush() + done <- err } } @@ -76,7 +76,7 @@ func (w *Async) Close() error { w.closeOnce.Do(func() { w.closed.Store(true) close(w.done) - _ = w.Flush() + err = w.Flush() close(w.flusher) w.systemGroup.Wait() }) @@ -84,11 +84,12 @@ func (w *Async) Close() error { } func (w *Async) Flush() error { - done := make(chan struct{}) + done := make(chan error) w.flusher <- done // wait until the flusher is done - <-done - return nil + err := <-done + close(done) + return err } func (w *Async) flush() error { From 66c04305d91755992930d4e0be5ee6273f4493fb Mon Sep 17 00:00:00 2001 From: Fred Date: Tue, 7 Apr 2026 15:05:06 +0100 Subject: [PATCH 13/15] refactor: enhance file handling with keep open timeout and add stats tracking --- util/write/file.go | 71 ++++++++++++++++++++++++++++++++------- util/write/file_option.go | 11 +++++- util/write/file_test.go | 47 ++++++++++++++++++++++++-- 3 files changed, 113 insertions(+), 16 deletions(-) diff --git a/util/write/file.go b/util/write/file.go index 80aff771..002ddcd2 100644 --- a/util/write/file.go +++ b/util/write/file.go @@ -3,24 +3,35 @@ package write import ( "errors" "os" + "sync" + "sync/atomic" + "time" "github.com/creativeprojects/resticprofile/platform" ) type File struct { - filename string - perm os.FileMode - flag int - keepOpen bool - handle *os.File + filename string + perm os.FileMode + flag int + keepOpen bool + keepOpenTimeout time.Duration + handle *os.File + mutex sync.Mutex + timer *time.Timer + timerMutex sync.Mutex + // stats + fileOpenCount atomic.Int32 + fileCloseCount atomic.Int32 } func NewFile(filename string, options ...FileOption) (f *File, err error) { f = &File{ - filename: filename, - perm: 0644, - flag: os.O_WRONLY | os.O_APPEND | os.O_CREATE, - keepOpen: !platform.IsWindows(), + filename: filename, + perm: 0644, + flag: os.O_WRONLY | os.O_APPEND | os.O_CREATE, + keepOpen: !platform.IsWindows(), + keepOpenTimeout: 10 * time.Millisecond, } for _, option := range options { @@ -37,17 +48,25 @@ func NewFile(filename string, options ...FileOption) (f *File, err error) { } func (f *File) open() error { + f.mutex.Lock() + defer f.mutex.Unlock() + if f.handle != nil { return nil } var err error f.handle, err = os.OpenFile(f.filename, f.flag, f.perm) + f.fileOpenCount.Add(1) return err } func (f *File) Close() error { + f.mutex.Lock() + defer f.mutex.Unlock() + if f.handle != nil { err := f.handle.Close() + f.fileCloseCount.Add(1) f.handle = nil return err } @@ -55,6 +74,9 @@ func (f *File) Close() error { } func (f *File) Flush() error { + f.mutex.Lock() + defer f.mutex.Unlock() + if f.handle != nil { return f.handle.Sync() } @@ -63,15 +85,38 @@ func (f *File) Flush() error { func (f *File) Write(data []byte) (n int, err error) { if !f.keepOpen { + f.stopCloseTimer() err := f.open() if err != nil { return 0, err } - defer func() { - err = errors.Join(err, f.Close()) - }() + defer f.resetCloseTimer() } n, err = f.handle.Write(data) - err = errors.Join(err, f.Flush()) return } + +func (f *File) stopCloseTimer() { + f.timerMutex.Lock() + defer f.timerMutex.Unlock() + if f.timer != nil { + f.timer.Stop() + } +} + +func (f *File) resetCloseTimer() { + f.timerMutex.Lock() + defer f.timerMutex.Unlock() + + if f.timer != nil { + f.timer.Stop() + } + f.timer = time.AfterFunc(f.keepOpenTimeout, func() { + _ = f.Close() + }) +} + +// stats returns the number of times the file was opened and closed +func (f *File) stats() (int32, int32) { + return f.fileOpenCount.Load(), f.fileCloseCount.Load() +} diff --git a/util/write/file_option.go b/util/write/file_option.go index f593dc9d..e59078ed 100644 --- a/util/write/file_option.go +++ b/util/write/file_option.go @@ -1,6 +1,9 @@ package write -import "os" +import ( + "os" + "time" +) type FileOption func(f *File) @@ -9,6 +12,12 @@ func WithFileKeepOpen(keepOpen bool) FileOption { return func(f *File) { f.keepOpen = keepOpen } } +// WithFileKeepOpenTimeout will automatically close the file after there was no more write during `timeout`. +// It defaults to 10ms. +func WithFileKeepOpenTimeout(timeout time.Duration) FileOption { + return func(f *File) { f.keepOpenTimeout = timeout } +} + // WithFilePerm sets file perms to apply when creating the file func WithFilePerm(perm os.FileMode) FileOption { return func(f *File) { f.perm = perm } diff --git a/util/write/file_test.go b/util/write/file_test.go index e7fd465a..65398c45 100644 --- a/util/write/file_test.go +++ b/util/write/file_test.go @@ -4,6 +4,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -13,16 +14,21 @@ func TestFileDefaultOption(t *testing.T) { dir := t.TempDir() filename := filepath.Join(dir, "testfile") - w, err := NewFile(filename) + w, err := NewFile(filename, WithFileKeepOpen(true), WithFileKeepOpenTimeout(1*time.Millisecond)) require.NoError(t, err) n, err := w.Write([]byte("hello world")) assert.NoError(t, err) assert.Equal(t, 11, n) + time.Sleep(2 * time.Millisecond) err = w.Close() assert.NoError(t, err) + opened, closed := w.stats() + assert.Equal(t, int32(1), opened) + assert.Equal(t, int32(1), closed) + content, err := os.ReadFile(filename) require.NoError(t, err) assert.Equal(t, "hello world", string(content)) @@ -32,7 +38,40 @@ func TestFileCloseAfterWrite(t *testing.T) { dir := t.TempDir() filename := filepath.Join(dir, "testfile") - w, err := NewFile(filename, WithFileKeepOpen(false)) + w, err := NewFile(filename, WithFileKeepOpen(false), WithFileKeepOpenTimeout(1*time.Millisecond)) + require.NoError(t, err) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + time.Sleep(2 * time.Millisecond) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello", string(content)) + + n, err = w.Write([]byte(" world")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + time.Sleep(2 * time.Millisecond) + + content, err = os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) + + err = w.Close() + assert.NoError(t, err) + + opened, closed := w.stats() + assert.Equal(t, int32(3), opened) // 1 during instantiation + 1 for each write + assert.Equal(t, int32(3), closed) // 1 during instantiation + 1 for each write +} + +func TestFileNoTimeToCloseAfterWrite(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename, WithFileKeepOpen(false), WithFileKeepOpenTimeout(1*time.Second)) require.NoError(t, err) n, err := w.Write([]byte("hello")) @@ -53,4 +92,8 @@ func TestFileCloseAfterWrite(t *testing.T) { err = w.Close() assert.NoError(t, err) + + opened, closed := w.stats() + assert.Equal(t, int32(2), opened) // 1 during instantiation + 1 for both write + assert.Equal(t, int32(2), closed) // 1 during instantiation + 1 for both write } From 890b393bad252db3b1ff872d8089086eefad270c Mon Sep 17 00:00:00 2001 From: Fred Date: Tue, 7 Apr 2026 15:31:27 +0100 Subject: [PATCH 14/15] refactor: improve async writer with flusher state management and add flush test for file writer --- util/write/async.go | 31 ++++++++++++++++++------------- util/write/async_test.go | 16 ++++++++++++++++ util/write/file_test.go | 29 +++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 13 deletions(-) diff --git a/util/write/async.go b/util/write/async.go index 54309667..bb9ab63f 100644 --- a/util/write/async.go +++ b/util/write/async.go @@ -12,21 +12,20 @@ import ( const ( asyncWriterDataChanSize = 64 asyncWriterFlushChanSize = 16 - asyncWriterBlockSize = 4 * 1024 - asyncWriterMaxBlockSize = asyncWriterDataChanSize * asyncWriterBlockSize ) var ErrAlreadyClosed = errors.New("writer already closed") type Async struct { - handler io.Writer - interval time.Duration - data chan []byte - flusher chan chan error - done chan struct{} - systemGroup sync.WaitGroup - closeOnce sync.Once - closed atomic.Bool + handler io.Writer + interval time.Duration + data chan []byte + flusher chan chan error + done chan struct{} + systemGroup sync.WaitGroup + closeOnce sync.Once + closed atomic.Bool + flusherClosed atomic.Bool } // NewAsync creates a writer that accumulates Write requests and writes them at a fixed rate (every 250 ms by default) @@ -36,7 +35,7 @@ func NewAsync(handler io.Writer, options ...AsyncOption) *Async { interval: 250 * time.Millisecond, data: make(chan []byte, asyncWriterDataChanSize), flusher: make(chan chan error, asyncWriterFlushChanSize), - done: make(chan struct{}), + done: make(chan struct{}), // channel closed after the first call to Close() } for _, option := range options { option(w) @@ -55,7 +54,7 @@ func (w *Async) intervalFlush() { for { select { case <-ticker.C: - go w.Flush() + w.flusher <- nil case <-w.done: ticker.Stop() return @@ -66,7 +65,9 @@ func (w *Async) intervalFlush() { func (w *Async) recvFlush() { for done := range w.flusher { err := w.flush() - done <- err + if done != nil { // some calls don't need to wait for the answer + done <- err + } } } @@ -77,6 +78,7 @@ func (w *Async) Close() error { w.closed.Store(true) close(w.done) err = w.Flush() + w.flusherClosed.Store(true) close(w.flusher) w.systemGroup.Wait() }) @@ -84,6 +86,9 @@ func (w *Async) Close() error { } func (w *Async) Flush() error { + if w.flusherClosed.Load() { + return fmt.Errorf("cannot write: %w", ErrAlreadyClosed) + } done := make(chan error) w.flusher <- done // wait until the flusher is done diff --git a/util/write/async_test.go b/util/write/async_test.go index 24a06605..31090a6f 100644 --- a/util/write/async_test.go +++ b/util/write/async_test.go @@ -24,6 +24,22 @@ func TestAsyncWriter(t *testing.T) { assert.Equal(t, "hello world", buffer.String()) } +func TestAsyncWriterFlushAfterClose(t *testing.T) { + buffer := new(bytes.Buffer) + w := NewAsync(buffer) + + n, err := w.Write([]byte("hello world")) + require.NoError(t, err) + assert.Equal(t, 11, n) + + require.NoError(t, w.Flush()) + require.NoError(t, w.Close()) + + require.ErrorIs(t, w.Flush(), ErrAlreadyClosed) + + assert.Equal(t, "hello world", buffer.String()) +} + func TestAsyncWriteMoreThanChannelSize(t *testing.T) { buffer := new(bytes.Buffer) w := NewAsync(buffer) diff --git a/util/write/file_test.go b/util/write/file_test.go index 65398c45..30906385 100644 --- a/util/write/file_test.go +++ b/util/write/file_test.go @@ -97,3 +97,32 @@ func TestFileNoTimeToCloseAfterWrite(t *testing.T) { assert.Equal(t, int32(2), opened) // 1 during instantiation + 1 for both write assert.Equal(t, int32(2), closed) // 1 during instantiation + 1 for both write } + +func TestFileCanFlush(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename) + require.NoError(t, err) + assert.NoError(t, w.Flush()) + + n, err := w.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.NoError(t, w.Flush()) + + n, err = w.Write([]byte(" world")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + assert.NoError(t, w.Flush()) + + content, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) + + assert.NoError(t, w.Flush()) + err = w.Close() + assert.NoError(t, err) + + assert.NoError(t, w.Flush()) +} From 2816684bafb2ed5478d12d0940c3574580190c75 Mon Sep 17 00:00:00 2001 From: Fred Date: Tue, 7 Apr 2026 16:34:09 +0100 Subject: [PATCH 15/15] refactor: implement Close method for Append and enhance file handling with error checks --- util/write/append.go | 10 ++++++++++ util/write/append_test.go | 13 +++++++++++++ util/write/async.go | 5 +++++ util/write/file.go | 11 +++++++++-- util/write/file_test.go | 12 ++++++++++++ util/write/mock_write_closer_test.go | 25 +++++++++++++++++++++++++ 6 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 util/write/mock_write_closer_test.go diff --git a/util/write/append.go b/util/write/append.go index 2f1cf25f..cb263a29 100644 --- a/util/write/append.go +++ b/util/write/append.go @@ -27,3 +27,13 @@ func (a *Append) Write(data []byte) (int, error) { } return a.next.Write(dst) } + +// Close will close the destination writer if accessible +func (a *Append) Close() error { + if closer, ok := a.next.(io.Closer); ok { + return closer.Close() + } + return nil +} + +var _ io.WriteCloser = &Append{} diff --git a/util/write/append_test.go b/util/write/append_test.go index ca201280..502b1891 100644 --- a/util/write/append_test.go +++ b/util/write/append_test.go @@ -36,3 +36,16 @@ func TestAppendLinesWithAppender(t *testing.T) { assert.Equal(t, "a\r\nb\r\nc\r\n", buffer.String()) } + +func TestAppendCallsCloseOnTarget(t *testing.T) { + called := 0 + mockClose := newMockWriteCloser(nil, func() error { + called++ + return nil + }) + + a := NewAppend(mockClose, nil) + a.Close() + + assert.Equal(t, 1, called) +} diff --git a/util/write/async.go b/util/write/async.go index bb9ab63f..e1876143 100644 --- a/util/write/async.go +++ b/util/write/async.go @@ -54,7 +54,9 @@ func (w *Async) intervalFlush() { for { select { case <-ticker.C: + ticker.Stop() // don't keep piling up if the flusher channel is already full w.flusher <- nil + ticker.Reset(w.interval) case <-w.done: ticker.Stop() return @@ -81,6 +83,9 @@ func (w *Async) Close() error { w.flusherClosed.Store(true) close(w.flusher) w.systemGroup.Wait() + if closer, ok := w.handler.(io.Closer); ok { + err = errors.Join(err, closer.Close()) + } }) return err } diff --git a/util/write/file.go b/util/write/file.go index 002ddcd2..9ccb0784 100644 --- a/util/write/file.go +++ b/util/write/file.go @@ -10,6 +10,8 @@ import ( "github.com/creativeprojects/resticprofile/platform" ) +var ErrAttemptToWriteOnClosedFile = errors.New("cannot write to a closed or unopened file") + type File struct { filename string perm os.FileMode @@ -55,8 +57,8 @@ func (f *File) open() error { return nil } var err error - f.handle, err = os.OpenFile(f.filename, f.flag, f.perm) f.fileOpenCount.Add(1) + f.handle, err = os.OpenFile(f.filename, f.flag, f.perm) return err } @@ -65,8 +67,8 @@ func (f *File) Close() error { defer f.mutex.Unlock() if f.handle != nil { - err := f.handle.Close() f.fileCloseCount.Add(1) + err := f.handle.Close() f.handle = nil return err } @@ -92,6 +94,11 @@ func (f *File) Write(data []byte) (n int, err error) { } defer f.resetCloseTimer() } + + if f.handle == nil { + return 0, ErrAttemptToWriteOnClosedFile + } + n, err = f.handle.Write(data) return } diff --git a/util/write/file_test.go b/util/write/file_test.go index 30906385..646de17f 100644 --- a/util/write/file_test.go +++ b/util/write/file_test.go @@ -126,3 +126,15 @@ func TestFileCanFlush(t *testing.T) { assert.NoError(t, w.Flush()) } + +func TestFileWriteOnClosedFile(t *testing.T) { + dir := t.TempDir() + filename := filepath.Join(dir, "testfile") + + w, err := NewFile(filename, WithFileKeepOpen(true)) + require.NoError(t, err) + assert.NoError(t, w.Close()) + + _, err = w.Write([]byte("aaa")) + assert.ErrorIs(t, err, ErrAttemptToWriteOnClosedFile) +} diff --git a/util/write/mock_write_closer_test.go b/util/write/mock_write_closer_test.go new file mode 100644 index 00000000..0c5e1e83 --- /dev/null +++ b/util/write/mock_write_closer_test.go @@ -0,0 +1,25 @@ +package write + +import "io" + +type mockWriteCloser struct { + writer func(data []byte) (int, error) + closer func() error +} + +func newMockWriteCloser(writer func(data []byte) (int, error), closer func() error) mockWriteCloser { + return mockWriteCloser{ + writer: writer, + closer: closer, + } +} + +func (m mockWriteCloser) Write(data []byte) (int, error) { + return m.writer(data) +} + +func (m mockWriteCloser) Close() error { + return m.closer() +} + +var _ io.WriteCloser = &Append{}