Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions monitor/hook/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/creativeprojects/clog"
"github.com/creativeprojects/resticprofile/config"
"github.com/creativeprojects/resticprofile/constants"
"github.com/creativeprojects/resticprofile/util"
"github.com/creativeprojects/resticprofile/util/templates"
)

Expand Down Expand Up @@ -71,12 +72,17 @@ func NewSender(certificates []string, userAgent string, timeout time.Duration, d
}
}

func (s *Sender) Send(cfg config.SendMonitoringSection, ctx Context) error {
func (s *Sender) Send(cfg config.SendMonitoringSection, ctx Context, env *util.Environment) error {
if cfg.URL.Value() == "" {
return errors.New("URL field is empty")
}
url := resolveURL(cfg.URL.Value(), ctx)
publicUrl := resolveURL(cfg.URL.String(), ctx)

if env == nil {
env = util.NewDefaultEnvironment(os.Environ()...)
}

url := resolveURL(cfg.URL.Value(), ctx, env)
publicUrl := resolveURL(cfg.URL.String(), ctx, env)
method := cfg.Method
if method == "" {
method = http.MethodGet
Expand All @@ -94,7 +100,7 @@ func (s *Sender) Send(cfg config.SendMonitoringSection, ctx Context) error {
bodyReader = bytes.NewBufferString(body)
}
if cfg.Body != "" {
body = resolveBody(cfg.Body, ctx)
body = resolveBody(cfg.Body, ctx, env)
bodyReader = bytes.NewBufferString(body)
}

Expand Down Expand Up @@ -203,7 +209,7 @@ func getRootCAs(certificates []string) *x509.CertPool {
return caCertPool
}

func resolveBody(body string, ctx Context) string {
func resolveBody(body string, ctx Context, env *util.Environment) string {
return os.Expand(body, func(s string) string {
switch s {
case constants.EnvProfileName:
Expand All @@ -228,12 +234,12 @@ func resolveBody(body string, ctx Context) string {
return "$" // allow to escape "$" as "$$"

default:
return os.Getenv(s)
return env.Get(s)
}
})
}

func resolveURL(url string, ctx Context) string {
func resolveURL(url string, ctx Context, env *util.Environment) string {
return os.Expand(url, func(s string) string {
switch s {
case constants.EnvProfileName:
Expand All @@ -258,7 +264,7 @@ func resolveURL(url string, ctx Context) string {
return "$" // allow to escape "$" as "$$"

default:
return os.Getenv(s)
return env.Get(s)
}
})
}
Expand Down
59 changes: 45 additions & 14 deletions monitor/hook/sender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import (

"github.com/creativeprojects/clog"
"github.com/creativeprojects/resticprofile/config"
"github.com/creativeprojects/resticprofile/constants"
"github.com/creativeprojects/resticprofile/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/creativeprojects/resticprofile/constants"
)

func TestSend(t *testing.T) {
Expand Down Expand Up @@ -95,7 +96,7 @@ func TestSend(t *testing.T) {
}

sender := NewSender(nil, "resticprofile_test", 10*time.Second, false)
err := sender.Send(testCase.cfg, ctx)
err := sender.Send(testCase.cfg, ctx, nil)
assert.NoError(t, err)

assert.Equal(t, testCase.calls, calls)
Expand All @@ -114,7 +115,7 @@ func TestDryRun(t *testing.T) {
sender := NewSender(nil, "", time.Second, true)
err := sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(server.URL),
}, Context{})
}, Context{}, nil)
assert.NoError(t, err)

assert.Equal(t, uint32(0), atomic.LoadUint32(&calls))
Expand All @@ -133,7 +134,7 @@ func TestSenderTimeout(t *testing.T) {
sender := NewSender(nil, "resticprofile_test", 300*time.Millisecond, false)
err := sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(server.URL),
}, Context{})
}, Context{}, nil)
assert.Error(t, err)

assert.Equal(t, uint32(1), atomic.LoadUint32(&startedCalls))
Expand All @@ -151,15 +152,15 @@ func TestInsecureRequests(t *testing.T) {
// 1: request will fail TLS
err := sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(server.URL),
}, Context{})
}, Context{}, nil)
assert.Error(t, err)
assert.Equal(t, 0, calls)

// 2: request allowing bad certificate
err = sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(server.URL),
SkipTLS: true,
}, Context{})
}, Context{}, nil)
assert.NoError(t, err)
assert.Equal(t, 1, calls)
}
Expand All @@ -179,7 +180,7 @@ func TestRequestWithCA(t *testing.T) {
// 1: request will fail TLS
err := sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(server.URL),
}, Context{})
}, Context{}, nil)
assert.Error(t, err)
assert.Equal(t, 0, calls)

Expand All @@ -196,7 +197,7 @@ func TestRequestWithCA(t *testing.T) {
sender = NewSender([]string{filename}, "resticprofile_test", 300*time.Millisecond, false)
err = sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(server.URL),
}, Context{})
}, Context{}, nil)
assert.NoError(t, err)
assert.Equal(t, 1, calls)
}
Expand All @@ -210,7 +211,7 @@ func TestFailedRequest(t *testing.T) {
sender := NewSender(nil, "resticprofile_test", 300*time.Millisecond, false)
err := sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(server.URL),
}, Context{})
}, Context{}, nil)
assert.Error(t, err)
}

Expand All @@ -230,7 +231,7 @@ func TestUserAgent(t *testing.T) {
Headers: []config.SendMonitoringHeader{
{Name: agentHeader, Value: config.NewConfidentialValue(testAgent)},
},
}, Context{})
}, Context{}, nil)
assert.NoError(t, err)
assert.Equal(t, 1, calls)
}
Expand Down Expand Up @@ -269,7 +270,7 @@ func TestConfidentialURL(t *testing.T) {
config.ProcessConfidentialValues(profile)

sender := NewSender(nil, "", 300*time.Millisecond, false)
err := sender.Send(profile.Backup.SendBefore[0], Context{})
err := sender.Send(profile.Backup.SendBefore[0], Context{}, nil)
require.NoError(t, err)
assert.Equal(t, 1, calls)
}
Expand Down Expand Up @@ -320,7 +321,7 @@ func TestURLEncoding(t *testing.T) {
sender := NewSender(nil, "", 300*time.Millisecond, false)
err := sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue(serverURL),
}, ctx)
}, ctx, nil)
assert.NoError(t, err)
assert.Equal(t, 1, calls)
}
Expand Down Expand Up @@ -359,7 +360,7 @@ func TestConfidentialHeader(t *testing.T) {
config.ProcessConfidentialValues(profile)

sender := NewSender(nil, "", 300*time.Millisecond, false)
err := sender.Send(profile.Backup.SendBefore[0], Context{})
err := sender.Send(profile.Backup.SendBefore[0], Context{}, nil)
require.NoError(t, err)
assert.Equal(t, 1, calls)
}
Expand Down Expand Up @@ -393,7 +394,7 @@ func TestParseTemplate(t *testing.T) {
URL: config.NewConfidentialValue(server.URL),
Method: http.MethodPost,
BodyTemplate: filename,
}, ctx)
}, ctx, nil)
assert.NoError(t, err)
}

Expand All @@ -418,3 +419,33 @@ func TestResponseSanitizer(t *testing.T) {
assert.Equal(t, test[1], responseContentSanitizer.ReplaceAllString(test[0], " "), "test #%d", i)
}
}

func TestCustomEnv(t *testing.T) {
calls := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buffer := bytes.Buffer{}
_, err := buffer.ReadFrom(r.Body)
assert.NoError(t, err)
r.Body.Close()

assert.Equal(t, "some test value\n", buffer.String())

calls++
}))
defer server.Close()

t.Setenv("TEST_MONITOR_URL", "should never be read")
t.Setenv("SOME_OS_ENV", "should never be read")

env := util.NewDefaultEnvironment()
env.Put("TEST_MONITOR_URL", server.URL)
env.Put("TEST_BODY_VALUE", "some test value")

sender := NewSender(nil, "", 300*time.Millisecond, false)
err := sender.Send(config.SendMonitoringSection{
URL: config.NewConfidentialValue("$TEST_MONITOR_URL"),
Body: "$TEST_BODY_VALUE\n$SOME_OS_ENV",
}, Context{}, env)
assert.NoError(t, err)
assert.Equal(t, 1, calls)
}
2 changes: 1 addition & 1 deletion wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ func (r *resticWrapper) sendMonitoring(sections []config.SendMonitoringSection,
for i, section := range sections {
clog.Debugf("starting %q from %s %d/%d", sendType, command, i+1, len(sections))
term.FlushAllOutput()
err := r.sender.Send(section, r.getContextWithError(err))
err := r.sender.Send(section, r.getContextWithError(err), r.profile.GetEnvironment(true))
if err != nil {
clog.Warningf("%q returned an error: %s", sendType, err.Error())
}
Expand Down
Loading