diff --git a/monitor/hook/sender.go b/monitor/hook/sender.go index e420490f..51633ef1 100644 --- a/monitor/hook/sender.go +++ b/monitor/hook/sender.go @@ -20,6 +20,7 @@ import ( "github.com/creativeprojects/clog" "github.com/creativeprojects/resticprofile/config" "github.com/creativeprojects/resticprofile/constants" + "github.com/creativeprojects/resticprofile/util" "github.com/creativeprojects/resticprofile/util/templates" ) @@ -71,12 +72,17 @@ func NewSender(certificates []string, userAgent string, timeout time.Duration, d } } -func (s *Sender) Send(cfg config.SendMonitoringSection, ctx Context) error { +func (s *Sender) Send(cfg config.SendMonitoringSection, ctx Context, env *util.Environment) error { if cfg.URL.Value() == "" { return errors.New("URL field is empty") } - url := resolveURL(cfg.URL.Value(), ctx) - publicUrl := resolveURL(cfg.URL.String(), ctx) + + if env == nil { + env = util.NewDefaultEnvironment(os.Environ()...) + } + + url := resolveURL(cfg.URL.Value(), ctx, env) + publicUrl := resolveURL(cfg.URL.String(), ctx, env) method := cfg.Method if method == "" { method = http.MethodGet @@ -94,7 +100,7 @@ func (s *Sender) Send(cfg config.SendMonitoringSection, ctx Context) error { bodyReader = bytes.NewBufferString(body) } if cfg.Body != "" { - body = resolveBody(cfg.Body, ctx) + body = resolveBody(cfg.Body, ctx, env) bodyReader = bytes.NewBufferString(body) } @@ -203,7 +209,7 @@ func getRootCAs(certificates []string) *x509.CertPool { return caCertPool } -func resolveBody(body string, ctx Context) string { +func resolveBody(body string, ctx Context, env *util.Environment) string { return os.Expand(body, func(s string) string { switch s { case constants.EnvProfileName: @@ -228,12 +234,12 @@ func resolveBody(body string, ctx Context) string { return "$" // allow to escape "$" as "$$" default: - return os.Getenv(s) + return env.Get(s) } }) } -func resolveURL(url string, ctx Context) string { +func resolveURL(url string, ctx Context, env *util.Environment) string { return os.Expand(url, func(s string) string { switch s { case constants.EnvProfileName: @@ -258,7 +264,7 @@ func resolveURL(url string, ctx Context) string { return "$" // allow to escape "$" as "$$" default: - return os.Getenv(s) + return env.Get(s) } }) } diff --git a/monitor/hook/sender_test.go b/monitor/hook/sender_test.go index 705498e2..bb7f5980 100644 --- a/monitor/hook/sender_test.go +++ b/monitor/hook/sender_test.go @@ -16,9 +16,10 @@ import ( "github.com/creativeprojects/clog" "github.com/creativeprojects/resticprofile/config" + "github.com/creativeprojects/resticprofile/constants" + "github.com/creativeprojects/resticprofile/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/creativeprojects/resticprofile/constants" ) func TestSend(t *testing.T) { @@ -95,7 +96,7 @@ func TestSend(t *testing.T) { } sender := NewSender(nil, "resticprofile_test", 10*time.Second, false) - err := sender.Send(testCase.cfg, ctx) + err := sender.Send(testCase.cfg, ctx, nil) assert.NoError(t, err) assert.Equal(t, testCase.calls, calls) @@ -114,7 +115,7 @@ func TestDryRun(t *testing.T) { sender := NewSender(nil, "", time.Second, true) err := sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(server.URL), - }, Context{}) + }, Context{}, nil) assert.NoError(t, err) assert.Equal(t, uint32(0), atomic.LoadUint32(&calls)) @@ -133,7 +134,7 @@ func TestSenderTimeout(t *testing.T) { sender := NewSender(nil, "resticprofile_test", 300*time.Millisecond, false) err := sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(server.URL), - }, Context{}) + }, Context{}, nil) assert.Error(t, err) assert.Equal(t, uint32(1), atomic.LoadUint32(&startedCalls)) @@ -151,7 +152,7 @@ func TestInsecureRequests(t *testing.T) { // 1: request will fail TLS err := sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(server.URL), - }, Context{}) + }, Context{}, nil) assert.Error(t, err) assert.Equal(t, 0, calls) @@ -159,7 +160,7 @@ func TestInsecureRequests(t *testing.T) { err = sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(server.URL), SkipTLS: true, - }, Context{}) + }, Context{}, nil) assert.NoError(t, err) assert.Equal(t, 1, calls) } @@ -179,7 +180,7 @@ func TestRequestWithCA(t *testing.T) { // 1: request will fail TLS err := sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(server.URL), - }, Context{}) + }, Context{}, nil) assert.Error(t, err) assert.Equal(t, 0, calls) @@ -196,7 +197,7 @@ func TestRequestWithCA(t *testing.T) { sender = NewSender([]string{filename}, "resticprofile_test", 300*time.Millisecond, false) err = sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(server.URL), - }, Context{}) + }, Context{}, nil) assert.NoError(t, err) assert.Equal(t, 1, calls) } @@ -210,7 +211,7 @@ func TestFailedRequest(t *testing.T) { sender := NewSender(nil, "resticprofile_test", 300*time.Millisecond, false) err := sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(server.URL), - }, Context{}) + }, Context{}, nil) assert.Error(t, err) } @@ -230,7 +231,7 @@ func TestUserAgent(t *testing.T) { Headers: []config.SendMonitoringHeader{ {Name: agentHeader, Value: config.NewConfidentialValue(testAgent)}, }, - }, Context{}) + }, Context{}, nil) assert.NoError(t, err) assert.Equal(t, 1, calls) } @@ -269,7 +270,7 @@ func TestConfidentialURL(t *testing.T) { config.ProcessConfidentialValues(profile) sender := NewSender(nil, "", 300*time.Millisecond, false) - err := sender.Send(profile.Backup.SendBefore[0], Context{}) + err := sender.Send(profile.Backup.SendBefore[0], Context{}, nil) require.NoError(t, err) assert.Equal(t, 1, calls) } @@ -320,7 +321,7 @@ func TestURLEncoding(t *testing.T) { sender := NewSender(nil, "", 300*time.Millisecond, false) err := sender.Send(config.SendMonitoringSection{ URL: config.NewConfidentialValue(serverURL), - }, ctx) + }, ctx, nil) assert.NoError(t, err) assert.Equal(t, 1, calls) } @@ -359,7 +360,7 @@ func TestConfidentialHeader(t *testing.T) { config.ProcessConfidentialValues(profile) sender := NewSender(nil, "", 300*time.Millisecond, false) - err := sender.Send(profile.Backup.SendBefore[0], Context{}) + err := sender.Send(profile.Backup.SendBefore[0], Context{}, nil) require.NoError(t, err) assert.Equal(t, 1, calls) } @@ -393,7 +394,7 @@ func TestParseTemplate(t *testing.T) { URL: config.NewConfidentialValue(server.URL), Method: http.MethodPost, BodyTemplate: filename, - }, ctx) + }, ctx, nil) assert.NoError(t, err) } @@ -418,3 +419,33 @@ func TestResponseSanitizer(t *testing.T) { assert.Equal(t, test[1], responseContentSanitizer.ReplaceAllString(test[0], " "), "test #%d", i) } } + +func TestCustomEnv(t *testing.T) { + calls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + buffer := bytes.Buffer{} + _, err := buffer.ReadFrom(r.Body) + assert.NoError(t, err) + r.Body.Close() + + assert.Equal(t, "some test value\n", buffer.String()) + + calls++ + })) + defer server.Close() + + t.Setenv("TEST_MONITOR_URL", "should never be read") + t.Setenv("SOME_OS_ENV", "should never be read") + + env := util.NewDefaultEnvironment() + env.Put("TEST_MONITOR_URL", server.URL) + env.Put("TEST_BODY_VALUE", "some test value") + + sender := NewSender(nil, "", 300*time.Millisecond, false) + err := sender.Send(config.SendMonitoringSection{ + URL: config.NewConfidentialValue("$TEST_MONITOR_URL"), + Body: "$TEST_BODY_VALUE\n$SOME_OS_ENV", + }, Context{}, env) + assert.NoError(t, err) + assert.Equal(t, 1, calls) +} diff --git a/wrapper.go b/wrapper.go index b91deb8a..ade69982 100644 --- a/wrapper.go +++ b/wrapper.go @@ -699,7 +699,7 @@ func (r *resticWrapper) sendMonitoring(sections []config.SendMonitoringSection, for i, section := range sections { clog.Debugf("starting %q from %s %d/%d", sendType, command, i+1, len(sections)) term.FlushAllOutput() - err := r.sender.Send(section, r.getContextWithError(err)) + err := r.sender.Send(section, r.getContextWithError(err), r.profile.GetEnvironment(true)) if err != nil { clog.Warningf("%q returned an error: %s", sendType, err.Error()) }