Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions sdk/configwatcher/fieldgroups.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package configwatcher

import (
"fmt"
"reflect"
"time"
)

// Group registers a set of watched fields mapped from a struct's `decree` tags.
// Use [Watcher.NewGroup] to create one.
type Group struct {
getters []func(rv reflect.Value) // one per tagged field
typ reflect.Type // struct type for validation
}

// NewGroup registers all `decree`-tagged fields in the struct pointed to by s
// with the watcher and returns a Group. The struct pointer must remain valid
// for the lifetime of the watcher.
//
// Example:
//
// type AppConfig struct {
// Name string `decree:"app.name"`
// Debug bool `decree:"app.debug"`
// }
// g, err := w.NewGroup(ctx, &AppConfig{})
func (w *Watcher) NewGroup(s any) (*Group, error) {
rv := reflect.ValueOf(s)
if rv.Kind() != reflect.Ptr || rv.IsNil() || rv.Elem().Kind() != reflect.Struct {
return nil, fmt.Errorf("configwatcher: NewGroup: s must be a non-nil pointer to a struct, got %T", s)
}
rv = rv.Elem()
rt := rv.Type()

g := &Group{typ: rt}

for i := range rv.NumField() {
field := rt.Field(i)
tag := field.Tag.Get("decree")
if tag == "" || tag == "-" {
continue
}
fv := rv.Field(i)
if !fv.CanSet() {
continue
}

getter, err := registerGroupField(w, tag, fv)
if err != nil {
return nil, fmt.Errorf("configwatcher: NewGroup: field %s (decree:%q): %w", field.Name, tag, err)
}
idx := i
g.getters = append(g.getters, func(target reflect.Value) {
getter(target.Field(idx))
})
}
return g, nil
}

// Fill populates the struct pointed to by s with the current values of all
// watched fields. s must be the same type that was passed to NewGroup.
func (g *Group) Fill(s any) error {
rv := reflect.ValueOf(s)
if rv.Kind() != reflect.Ptr || rv.IsNil() || rv.Elem().Kind() != reflect.Struct {
return fmt.Errorf("configwatcher: Group.Fill: s must be a non-nil pointer to a struct")
}
rv = rv.Elem()
if rv.Type() != g.typ {
return fmt.Errorf("configwatcher: Group.Fill: type mismatch: got %s, want %s", rv.Type(), g.typ)
}
for _, get := range g.getters {
get(rv)
}
return nil
}

var durType = reflect.TypeOf(time.Duration(0))

func registerGroupField(w *Watcher, path string, fv reflect.Value) (func(dst reflect.Value), error) {
switch {
case fv.Type() == durType:
val, err := w.Duration(path, 0)
if err != nil {
return nil, err
}
return func(dst reflect.Value) { dst.SetInt(int64(val.Get())) }, nil

case fv.Kind() == reflect.String:
val, err := w.String(path, "")
if err != nil {
return nil, err
}
return func(dst reflect.Value) { dst.SetString(val.Get()) }, nil

case fv.Kind() == reflect.Bool:
val, err := w.Bool(path, false)
if err != nil {
return nil, err
}
return func(dst reflect.Value) { dst.SetBool(val.Get()) }, nil

case fv.Kind() >= reflect.Int && fv.Kind() <= reflect.Int64:
val, err := w.Int(path, 0)
if err != nil {
return nil, err
}
return func(dst reflect.Value) { dst.SetInt(val.Get()) }, nil

case fv.Kind() == reflect.Float64 || fv.Kind() == reflect.Float32:
val, err := w.Float(path, 0)
if err != nil {
return nil, err
}
return func(dst reflect.Value) { dst.SetFloat(val.Get()) }, nil

default:
return nil, fmt.Errorf("unsupported field type %s", fv.Type())
}
}
113 changes: 113 additions & 0 deletions sdk/configwatcher/fieldgroups_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package configwatcher

import (
"context"
"testing"
"time"

"github.com/opendecree/decree/sdk/configclient"
)

func TestNewGroup_FillAllTypes(t *testing.T) {
tr := &mockTransport{
getConfigFn: func(_ context.Context, _ *configclient.GetConfigRequest) (*configclient.GetConfigResponse, error) {
return &configclient.GetConfigResponse{
Values: []configclient.ConfigValue{
{FieldPath: "app.name", Value: configclient.StringVal("myapp")},
{FieldPath: "app.debug", Value: configclient.BoolVal(true)},
{FieldPath: "app.count", Value: configclient.IntVal(42)},
{FieldPath: "app.rate", Value: configclient.FloatVal(1.5)},
{FieldPath: "app.timeout", Value: configclient.DurationVal(5 * time.Second)},
},
}, nil
},
subscribeFn: func(ctx context.Context, _ *configclient.SubscribeRequest) (configclient.Subscription, error) {
return newMockSubscription(ctx), nil
},
}

type Config struct {
Name string `decree:"app.name"`
Debug bool `decree:"app.debug"`
Count int64 `decree:"app.count"`
Rate float64 `decree:"app.rate"`
Timeout time.Duration `decree:"app.timeout"`
Ignored string
Skipped string `decree:"-"`
}

w := New(tr, "t1")
g, err := w.NewGroup(&Config{})
if err != nil {
t.Fatalf("NewGroup: %v", err)
}
if err := w.Start(context.Background()); err != nil {
t.Fatalf("Start: %v", err)
}
defer w.Close()

var cfg Config
if err := g.Fill(&cfg); err != nil {
t.Fatalf("Fill: %v", err)
}

if cfg.Name != "myapp" {
t.Errorf("Name = %q, want %q", cfg.Name, "myapp")
}
if !cfg.Debug {
t.Error("Debug = false, want true")
}
if cfg.Count != 42 {
t.Errorf("Count = %d, want 42", cfg.Count)
}
if cfg.Rate != 1.5 {
t.Errorf("Rate = %f, want 1.5", cfg.Rate)
}
if cfg.Timeout != 5*time.Second {
t.Errorf("Timeout = %v, want 5s", cfg.Timeout)
}
}

func TestNewGroup_NonPointerError(t *testing.T) {
tr := &mockTransport{
getConfigFn: func(_ context.Context, _ *configclient.GetConfigRequest) (*configclient.GetConfigResponse, error) {
return &configclient.GetConfigResponse{}, nil
},
subscribeFn: func(ctx context.Context, _ *configclient.SubscribeRequest) (configclient.Subscription, error) {
return newMockSubscription(ctx), nil
},
}
w := New(tr, "t1")
type S struct{}
_, err := w.NewGroup(S{})
if err == nil {
t.Error("expected error for non-pointer, got nil")
}
}

func TestGroup_Fill_TypeMismatch(t *testing.T) {
tr := &mockTransport{
getConfigFn: func(_ context.Context, _ *configclient.GetConfigRequest) (*configclient.GetConfigResponse, error) {
return &configclient.GetConfigResponse{}, nil
},
subscribeFn: func(ctx context.Context, _ *configclient.SubscribeRequest) (configclient.Subscription, error) {
return newMockSubscription(ctx), nil
},
}
w := New(tr, "t1")
type A struct {
X string `decree:"x"`
}
type B struct {
X string `decree:"x"`
}

g, err := w.NewGroup(&A{})
if err != nil {
t.Fatalf("NewGroup: %v", err)
}
var b B
if err := g.Fill(&b); err == nil {
t.Error("expected type mismatch error, got nil")
}
}
Loading