diff --git a/context.go b/context.go index 2015306..3b1ca86 100644 --- a/context.go +++ b/context.go @@ -40,6 +40,7 @@ package log import ( "context" "fmt" + "log/slog" "github.com/sirupsen/logrus" ) @@ -127,6 +128,9 @@ func SetLevel(level string) error { } L.Logger.SetLevel(lvl) + if slogOut != nil { + slogLevel.Set(logrusToSlogLevel(lvl)) + } return nil } @@ -155,15 +159,26 @@ func SetFormat(format OutputFormat) error { TimestampFormat: RFC3339NanoFixed, FullTimestamp: true, }) - return nil case JSONFormat: L.Logger.SetFormatter(&logrus.JSONFormatter{ TimestampFormat: RFC3339NanoFixed, }) - return nil default: return fmt.Errorf("unknown log format: %s", format) } + + if slogOut != nil { + var handler slog.Handler + switch format { + case TextFormat: + handler = slog.NewTextHandler(slogOut, &slog.HandlerOptions{Level: slogLevel}) + case JSONFormat: + handler = slog.NewJSONHandler(slogOut, &slog.HandlerOptions{Level: slogLevel}) + } + slog.SetDefault(slog.New(handler)) + } + + return nil } // WithLogger returns a new context with the provided logger. Use in diff --git a/slog.go b/slog.go new file mode 100644 index 0000000..dc1db2a --- /dev/null +++ b/slog.go @@ -0,0 +1,95 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package log + +import ( + "context" + "io" + "log/slog" + "sync" + + "github.com/sirupsen/logrus" +) + +// slogOut is used to set the slog logger when setting output format. +var slogOut io.Writer + +// slogLevel is used to control the slog handler's level when slog output is active. +var slogLevel = &slog.LevelVar{} + +// slogOnce guards UseSlog so repeated calls do not stack up hooks or +// reset slogOut to the discard writer installed on the first call. +var slogOnce sync.Once + +func UseSlog() { + slogOnce.Do(func() { + L.Logger.SetNoLock() + L.Logger.AddHook(slogHook{}) + slogOut = L.Logger.Out + L.Logger.SetOutput(io.Discard) + slogLevel.Set(logrusToSlogLevel(L.Logger.GetLevel())) + }) +} + +type slogHook struct{} + +func (hook slogHook) Levels() []logrus.Level { + return logrus.AllLevels +} + +func logrusToSlogLevel(l logrus.Level) slog.Level { + switch l { + case logrus.PanicLevel: + return slog.LevelError + 4 + case logrus.FatalLevel: + return slog.LevelError + 2 + case logrus.ErrorLevel: + return slog.LevelError + case logrus.WarnLevel: + return slog.LevelWarn + case logrus.DebugLevel: + return slog.LevelDebug + case logrus.TraceLevel: + return slog.LevelDebug - 4 + default: + return slog.LevelInfo + } +} + +func (hook slogHook) Fire(entry *logrus.Entry) error { + level := logrusToSlogLevel(entry.Level) + + handler := slog.Default().Handler() + + ctx := entry.Context + if ctx == nil { + ctx = context.Background() + } + + if !handler.Enabled(ctx, level) { + return nil + } + + record := slog.NewRecord(entry.Time, level, entry.Message, 0) + + // Convert logrus fields to slog attributes. + for k, v := range entry.Data { + record.AddAttrs(slog.Any(k, v)) + } + + return handler.Handle(ctx, record) +} diff --git a/slog_test.go b/slog_test.go new file mode 100644 index 0000000..950978f --- /dev/null +++ b/slog_test.go @@ -0,0 +1,275 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package log + +import ( + "bytes" + "context" + "log/slog" + "strings" + "sync" + "testing" + + "github.com/sirupsen/logrus" +) + +// setupSlogTest sets up UseSlogHook with a captured slog buffer and a separate +// buffer for the original logrus output. It restores all global state on cleanup. +func setupSlogTest(t *testing.T) (slogBuf, logrusBuf *bytes.Buffer) { + t.Helper() + + // Save global state to restore later. + oldLogger := L.Logger + oldDefault := slog.Default() + oldSlogOut := slogOut + + // Create a fresh logrus logger so we don't mutate the real global. + logger := logrus.New() + logger.SetLevel(logrus.TraceLevel) + logrusBuf = &bytes.Buffer{} + logger.SetOutput(logrusBuf) + + L = &Entry{ + Logger: logger, + Data: make(Fields, 6), + } + + // Reset slogOnce so UseSlog runs against this fresh logger. + slogOnce = sync.Once{} + + // Activate the slog hook — this redirects logrus output to io.Discard + // and sets slogOut to the logger's original output. + UseSlog() + + // Now install a slog handler that writes to our test buffer. + slogBuf = &bytes.Buffer{} + handler := slog.NewTextHandler(slogBuf, &slog.HandlerOptions{ + Level: slog.LevelDebug - 4, // capture all levels including trace + }) + slog.SetDefault(slog.New(handler)) + + t.Cleanup(func() { + L = &Entry{ + Logger: oldLogger, + Data: make(Fields, 6), + } + slog.SetDefault(oldDefault) + slogOut = oldSlogOut + slogOnce = sync.Once{} + }) + + return slogBuf, logrusBuf +} + +func TestUseSlogHook(t *testing.T) { + slogBuf, logrusBuf := setupSlogTest(t) + + L.Info("hello from L") + + slogOutput := slogBuf.String() + logrusOutput := logrusBuf.String() + + if !strings.Contains(slogOutput, "hello from L") { + t.Errorf("expected slog output to contain message, got: %s", slogOutput) + } + if logrusOutput != "" { + t.Errorf("expected no logrus output, got: %s", logrusOutput) + } +} + +func TestUseSlogHookWithFields(t *testing.T) { + slogBuf, logrusBuf := setupSlogTest(t) + + L.WithFields(Fields{ + "component": "test", + "count": 42, + }).Warn("something happened") + + slogOutput := slogBuf.String() + + if !strings.Contains(slogOutput, "something happened") { + t.Errorf("expected slog output to contain message, got: %s", slogOutput) + } + if !strings.Contains(slogOutput, "component=test") { + t.Errorf("expected slog output to contain component field, got: %s", slogOutput) + } + if !strings.Contains(slogOutput, "count=42") { + t.Errorf("expected slog output to contain count field, got: %s", slogOutput) + } + if logrusBuf.Len() != 0 { + t.Errorf("expected no logrus output, got: %s", logrusBuf.String()) + } +} + +func TestUseSlogHookWithContext(t *testing.T) { + slogBuf, logrusBuf := setupSlogTest(t) + + ctx := context.Background() + logger := G(ctx).WithField("request_id", "abc123") + ctx = WithLogger(ctx, logger) + + G(ctx).Info("context logger message") + + slogOutput := slogBuf.String() + + if !strings.Contains(slogOutput, "context logger message") { + t.Errorf("expected slog output to contain message, got: %s", slogOutput) + } + if !strings.Contains(slogOutput, "request_id=abc123") { + t.Errorf("expected slog output to contain request_id field, got: %s", slogOutput) + } + if logrusBuf.Len() != 0 { + t.Errorf("expected no logrus output, got: %s", logrusBuf.String()) + } +} + +func TestUseSlogHookLevels(t *testing.T) { + slogBuf, logrusBuf := setupSlogTest(t) + + tests := []struct { + name string + logFunc func(string, ...any) + message string + }{ + {"trace", L.Tracef, "trace-msg"}, + {"debug", L.Debugf, "debug-msg"}, + {"info", L.Infof, "info-msg"}, + {"warn", L.Warnf, "warn-msg"}, + {"error", L.Errorf, "error-msg"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + slogBuf.Reset() + logrusBuf.Reset() + + tc.logFunc(tc.message) + + if !strings.Contains(slogBuf.String(), tc.message) { + t.Errorf("expected slog output to contain %q, got: %s", tc.message, slogBuf.String()) + } + if logrusBuf.Len() != 0 { + t.Errorf("expected no logrus output, got: %s", logrusBuf.String()) + } + }) + } +} + +func TestSetFormatWithSlog(t *testing.T) { + // SetFormat reconfigures the slog default handler to write to slogOut. + // After SetFormat, logging through L should still go to slog (via slogOut), + // and nothing should go to the logrus output (which is io.Discard). + _, _ = setupSlogTest(t) + + // Replace slogOut with our own buffer so we can capture what SetFormat configures. + var slogBuf bytes.Buffer + slogOut = &slogBuf + + t.Run("text format", func(t *testing.T) { + slogBuf.Reset() + + if err := SetFormat(TextFormat); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + L.Info("text format message") + + if !strings.Contains(slogBuf.String(), "text format message") { + t.Errorf("expected slog output to contain message, got: %s", slogBuf.String()) + } + }) + + t.Run("json format", func(t *testing.T) { + slogBuf.Reset() + + if err := SetFormat(JSONFormat); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + L.Info("json format message") + + slogOutput := slogBuf.String() + if !strings.Contains(slogOutput, "json format message") { + t.Errorf("expected slog output to contain message, got: %s", slogOutput) + } + if !strings.Contains(slogOutput, "{") { + t.Errorf("expected JSON format output, got: %s", slogOutput) + } + }) +} + +func TestSetLevelWithSlog(t *testing.T) { + slogBuf, _ := setupSlogTest(t) + + // Set level to warn — debug/info messages should be suppressed by slog. + if err := SetLevel("warn"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Also reconfigure slog handler to use slogLevel (as SetFormat does). + slogOut = slogBuf + if err := SetFormat(TextFormat); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + slogBuf.Reset() + L.Info("should be hidden") + if slogBuf.Len() != 0 { + t.Errorf("expected info message to be suppressed at warn level, got: %s", slogBuf.String()) + } + + slogBuf.Reset() + L.Warn("should be visible") + if !strings.Contains(slogBuf.String(), "should be visible") { + t.Errorf("expected warn message to appear, got: %s", slogBuf.String()) + } + + // Raise level back to debug — info should now appear. + if err := SetLevel("debug"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + slogBuf.Reset() + L.Info("now visible") + if !strings.Contains(slogBuf.String(), "now visible") { + t.Errorf("expected info message to appear at debug level, got: %s", slogBuf.String()) + } +} + +func TestLogrusToSlogLevel(t *testing.T) { + tests := []struct { + logrusLevel logrus.Level + slogLevel slog.Level + }{ + {logrus.PanicLevel, slog.LevelError + 4}, + {logrus.FatalLevel, slog.LevelError + 2}, + {logrus.ErrorLevel, slog.LevelError}, + {logrus.WarnLevel, slog.LevelWarn}, + {logrus.InfoLevel, slog.LevelInfo}, + {logrus.DebugLevel, slog.LevelDebug}, + {logrus.TraceLevel, slog.LevelDebug - 4}, + } + + for _, tc := range tests { + t.Run(tc.logrusLevel.String(), func(t *testing.T) { + got := logrusToSlogLevel(tc.logrusLevel) + if got != tc.slogLevel { + t.Errorf("logrusToSlogLevel(%v) = %v, want %v", tc.logrusLevel, got, tc.slogLevel) + } + }) + } +}