Skip to content

Commit 97ca0aa

Browse files
committed
fix(config): block env overrides for internal fields
1 parent f829ffe commit 97ca0aa

File tree

2 files changed

+78
-5
lines changed

2 files changed

+78
-5
lines changed

pkg/config/config.go

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
_ "embed"
7+
"encoding"
78
"encoding/base64"
89
"encoding/json"
910
"fmt"
@@ -16,13 +17,15 @@ import (
1617
"os"
1718
"path"
1819
"path/filepath"
20+
"reflect"
1921
"regexp"
2022
"slices"
2123
"sort"
2224
"strconv"
2325
"strings"
2426
"text/template"
2527
"time"
28+
"unicode"
2629

2730
"github.com/BurntSushi/toml"
2831
"github.com/docker/go-units"
@@ -458,17 +461,16 @@ func (c *config) Eject(w io.Writer) error {
458461

459462
// Loads custom config file to struct fields tagged with toml.
460463
func (c *config) loadFromFile(filename string, fsys fs.FS) error {
461-
v := viper.NewWithOptions(
462-
viper.ExperimentalBindStruct(),
463-
viper.EnvKeyReplacer(strings.NewReplacer(".", "_")),
464-
)
464+
v := viper.New()
465465
v.SetEnvPrefix("SUPABASE")
466-
v.AutomaticEnv()
467466
if err := c.mergeDefaultValues(v); err != nil {
468467
return err
469468
} else if err := mergeFileConfig(v, filename, fsys); err != nil {
470469
return err
471470
}
471+
if err := bindUserConfigEnv(v, reflect.TypeOf(*c), ""); err != nil {
472+
return err
473+
}
472474
// Find [remotes.*] block to override base config
473475
idToName := map[string]string{}
474476
for name, remote := range v.GetStringMap("remotes") {
@@ -488,6 +490,62 @@ func (c *config) loadFromFile(filename string, fsys fs.FS) error {
488490
return c.load(v)
489491
}
490492

493+
func bindUserConfigEnv(v *viper.Viper, t reflect.Type, prefix string) error {
494+
textUnmarshaler := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
495+
for i := 0; i < t.NumField(); i++ {
496+
field := t.Field(i)
497+
if field.PkgPath != "" {
498+
continue
499+
}
500+
if field.Anonymous {
501+
if err := bindUserConfigEnv(v, field.Type, prefix); err != nil {
502+
return err
503+
}
504+
continue
505+
}
506+
if tag := strings.Split(field.Tag.Get("toml"), ",")[0]; tag == "-" {
507+
continue
508+
}
509+
key := strings.Split(field.Tag.Get("json"), ",")[0]
510+
if key == "-" {
511+
continue
512+
} else if key == "" {
513+
key = toSnakeCase(field.Name)
514+
}
515+
if len(prefix) > 0 {
516+
key = prefix + "." + key
517+
}
518+
fieldType := field.Type
519+
if fieldType.Kind() == reflect.Ptr {
520+
fieldType = fieldType.Elem()
521+
}
522+
if fieldType.Kind() == reflect.Struct && !fieldType.Implements(textUnmarshaler) && !reflect.PointerTo(fieldType).Implements(textUnmarshaler) {
523+
if err := bindUserConfigEnv(v, fieldType, key); err != nil {
524+
return err
525+
}
526+
continue
527+
}
528+
if err := v.BindEnv(key); err != nil {
529+
return errors.Errorf("failed to bind env override for %s: %w", key, err)
530+
}
531+
}
532+
return nil
533+
}
534+
535+
func toSnakeCase(s string) string {
536+
var b strings.Builder
537+
for i, r := range s {
538+
if unicode.IsUpper(r) {
539+
if i > 0 {
540+
b.WriteByte('_')
541+
}
542+
r = unicode.ToLower(r)
543+
}
544+
b.WriteRune(r)
545+
}
546+
return b.String()
547+
}
548+
491549
func (c *config) mergeDefaultValues(v *viper.Viper) error {
492550
v.SetConfigType("toml")
493551
var buf bytes.Buffer

pkg/config/config_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,21 @@ func TestRemoteOverride(t *testing.T) {
154154
})
155155
}
156156

157+
func TestEnvOverridesSkipInternalFields(t *testing.T) {
158+
config := NewConfig()
159+
fsys := fs.MapFS{
160+
"supabase/config.toml": &fs.MapFile{Data: testInitConfigEmbed},
161+
"supabase/templates/invite.html": &fs.MapFile{},
162+
}
163+
t.Setenv("SUPABASE_HOSTNAME", "evil.example.com")
164+
t.Setenv("SUPABASE_AUTH_SITE_URL", "http://preview.com")
165+
t.Setenv("AUTH_SEND_SMS_SECRETS", "v1,whsec_aWxpa2VzdXBhYmFzZXZlcnltdWNoYW5kaWhvcGV5b3Vkb3Rvbw==")
166+
167+
require.NoError(t, config.Load("", fsys))
168+
assert.Equal(t, "127.0.0.1", config.Hostname)
169+
assert.Equal(t, "http://preview.com", config.Auth.SiteUrl)
170+
}
171+
157172
func TestFileSizeLimitConfigParsing(t *testing.T) {
158173
t.Run("test file size limit parsing number", func(t *testing.T) {
159174
var testConfig config

0 commit comments

Comments
 (0)