Skip to content

Commit 51d877d

Browse files
feat(client): allow overriding unions
1 parent f0b66f8 commit 51d877d

2 files changed

Lines changed: 44 additions & 2 deletions

File tree

packages/param/encoder.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func MarshalWithExtras[T ParamStruct, R any](f T, underlying any, extras map[str
6868
// MarshalUnion uses a shimmed 'encoding/json' from Go 1.24, to support the 'omitzero' tag
6969
//
7070
// Stability for the API of MarshalUnion is not guaranteed.
71-
func MarshalUnion[T any](variants ...any) ([]byte, error) {
71+
func MarshalUnion[T ParamStruct](metadata T, variants ...any) ([]byte, error) {
7272
nPresent := 0
7373
presentIdx := -1
7474
for i, variant := range variants {
@@ -78,6 +78,9 @@ func MarshalUnion[T any](variants ...any) ([]byte, error) {
7878
}
7979
}
8080
if nPresent == 0 || presentIdx == -1 {
81+
if ovr, ok := metadata.Overrides(); ok {
82+
return shimjson.Marshal(ovr)
83+
}
8184
return []byte(`null`), nil
8285
} else if nPresent > 1 {
8386
return nil, &json.MarshalerError{

packages/param/encoder_test.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,11 @@ func TestExtraFieldsForceOmitted(t *testing.T) {
218218
type UnionWithDates struct {
219219
OfDate param.Opt[time.Time]
220220
OfTime param.Opt[time.Time]
221+
param.APIUnion
221222
}
222223

223224
func (r UnionWithDates) MarshalJSON() (data []byte, err error) {
224-
return param.MarshalUnion[UnionWithDates](param.EncodedAsDate(r.OfDate), r.OfTime)
225+
return param.MarshalUnion(r, param.EncodedAsDate(r.OfDate), r.OfTime)
225226
}
226227

227228
func TestUnionDateMarshal(t *testing.T) {
@@ -324,3 +325,41 @@ func TestOptionalInterfaceAssignability(t *testing.T) {
324325

325326
notOpt.implOpt() // silence the warning
326327
}
328+
329+
type PrimitiveUnion struct {
330+
OfString param.Opt[string]
331+
OfInt param.Opt[int]
332+
param.APIUnion
333+
}
334+
335+
func (p PrimitiveUnion) MarshalJSON() (data []byte, err error) {
336+
return param.MarshalUnion(p, p.OfString, p.OfInt)
337+
}
338+
339+
func TestOverriddenUnion(t *testing.T) {
340+
tests := map[string]struct {
341+
value PrimitiveUnion
342+
expected string
343+
}{
344+
"string": {
345+
param.Override[PrimitiveUnion](json.RawMessage(`"hello"`)),
346+
`"hello"`,
347+
},
348+
"int": {
349+
param.Override[PrimitiveUnion](json.RawMessage(`42`)),
350+
`42`,
351+
},
352+
}
353+
354+
for name, test := range tests {
355+
t.Run(name, func(t *testing.T) {
356+
b, err := json.Marshal(test.value)
357+
if err != nil {
358+
t.Fatalf("didn't expect error %v, expected %s", err, test.expected)
359+
}
360+
if string(b) != test.expected {
361+
t.Fatalf("expected %s, received %s", test.expected, string(b))
362+
}
363+
})
364+
}
365+
}

0 commit comments

Comments
 (0)