Skip to content

Commit 539a65e

Browse files
feat(client): optimize json encoder for internal types
1 parent 39ed372 commit 539a65e

7 files changed

Lines changed: 260 additions & 37 deletions

File tree

internal/encoding/json/encode.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,21 @@ import (
173173
// JSON cannot represent cyclic data structures and Marshal does not
174174
// handle them. Passing cyclic structures to Marshal will result in
175175
// an error.
176-
func Marshal(v any) ([]byte, error) {
176+
// EDIT(begin): add optimization options
177+
func Marshal(v any, opts ...Option) ([]byte, error) {
178+
// EDIT(end): add optimization options
177179
e := newEncodeState()
178180
defer encodeStatePool.Put(e)
179181

180-
// SHIM(begin): don't escape HTML by default
181-
err := e.marshal(v, encOpts{escapeHTML: shims.EscapeHTMLByDefault})
182+
// EDIT(begin): don't escape HTML by default, and apply options
183+
encOpts := encOpts{escapeHTML: shims.EscapeHTMLByDefault}
184+
if opts != nil {
185+
encOpts = encOpts.apply(opts...)
186+
}
187+
err := e.marshal(v, encOpts)
182188
// ORIGINAL:
183189
// err := e.marshal(v, encOpts{escapeHTML: true})
184-
// SHIM(end)
190+
// EDIT(end)
185191
if err != nil {
186192
return nil, err
187193
}
@@ -352,6 +358,9 @@ type encOpts struct {
352358
// EDIT(begin): save the timefmt
353359
timefmt string
354360
// EDIT(end)
361+
// EDIT(begin): add optimization to skip compaction
362+
skipCompaction bool
363+
// EDIT(end)
355364
}
356365

357366
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
@@ -483,7 +492,7 @@ func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
483492
if err == nil {
484493
e.Grow(len(b))
485494
out := e.AvailableBuffer()
486-
out, err = appendCompact(out, b, opts.escapeHTML)
495+
out, err = appendCompact(out, b, opts)
487496
e.Buffer.Write(out)
488497
}
489498
if err != nil {
@@ -509,7 +518,7 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
509518
if err == nil {
510519
e.Grow(len(b))
511520
out := e.AvailableBuffer()
512-
out, err = appendCompact(out, b, opts.escapeHTML)
521+
out, err = appendCompact(out, b, opts)
513522
e.Buffer.Write(out)
514523
}
515524
if err != nil {

internal/encoding/json/indent.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
package json
66

7-
import "bytes"
7+
import (
8+
"bytes"
9+
)
810

911
// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
1012
// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
@@ -41,12 +43,21 @@ func appendHTMLEscape(dst, src []byte) []byte {
4143
func Compact(dst *bytes.Buffer, src []byte) error {
4244
dst.Grow(len(src))
4345
b := dst.AvailableBuffer()
44-
b, err := appendCompact(b, src, false)
46+
b, err := appendCompact(b, src, encOpts{})
4547
dst.Write(b)
4648
return err
4749
}
4850

49-
func appendCompact(dst, src []byte, escape bool) ([]byte, error) {
51+
func appendCompact(dst, src []byte, opts encOpts) ([]byte, error) {
52+
// EDIT(begin): optimize for skipCompaction
53+
if opts.skipCompaction {
54+
dst = append(dst, src...)
55+
return dst, nil
56+
}
57+
58+
escape := opts.escapeHTML
59+
// EDIT(end)
60+
5061
origLen := len(dst)
5162
scan := newScanner()
5263
defer freeScanner(scan)

internal/encoding/json/opt.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// EDIT(begin): add custom options for JSON encoding
2+
package json
3+
4+
type Option func(*encOpts)
5+
6+
// Every time a sub-type of [json.Marshaler] is encountered,
7+
// skip a redundant and costly compaction step, trust it to self-compact.
8+
//
9+
// This is a divergence from the standard library behavior, and is only guaranteed
10+
// safe with SDK types.
11+
func WithSkipCompaction(b bool) Option {
12+
return func(eos *encOpts) {
13+
eos.skipCompaction = true
14+
}
15+
}
16+
17+
func (eos encOpts) apply(opts ...Option) encOpts {
18+
for _, opt := range opts {
19+
opt(&eos)
20+
}
21+
return eos
22+
}
23+
24+
// EDIT(end)

internal/encoding/json/stream.go

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package json
66

77
import (
88
"bytes"
9-
"errors"
109
"io"
1110
)
1211

@@ -253,30 +252,34 @@ func (enc *Encoder) SetEscapeHTML(on bool) {
253252
enc.escapeHTML = on
254253
}
255254

256-
// RawMessage is a raw encoded JSON value.
257-
// It implements [Marshaler] and [Unmarshaler] and can
258-
// be used to delay JSON decoding or precompute a JSON encoding.
259-
type RawMessage []byte
260-
261-
// MarshalJSON returns m as the JSON encoding of m.
262-
func (m RawMessage) MarshalJSON() ([]byte, error) {
263-
if m == nil {
264-
return []byte("null"), nil
265-
}
266-
return m, nil
267-
}
268-
269-
// UnmarshalJSON sets *m to a copy of data.
270-
func (m *RawMessage) UnmarshalJSON(data []byte) error {
271-
if m == nil {
272-
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
273-
}
274-
*m = append((*m)[0:0], data...)
275-
return nil
276-
}
277-
278-
var _ Marshaler = (*RawMessage)(nil)
279-
var _ Unmarshaler = (*RawMessage)(nil)
255+
// EDIT(begin): remove RawMessage
256+
//
257+
// // RawMessage is a raw encoded JSON value.
258+
// // It implements [Marshaler] and [Unmarshaler] and can
259+
// // be used to delay JSON decoding or precompute a JSON encoding.
260+
// type RawMessage []byte
261+
//
262+
// // MarshalJSON returns m as the JSON encoding of m.
263+
// func (m RawMessage) MarshalJSON() ([]byte, error) {
264+
// if m == nil {
265+
// return []byte("null"), nil
266+
// }
267+
// return m, nil
268+
// }
269+
//
270+
// // UnmarshalJSON sets *m to a copy of data.
271+
// func (m *RawMessage) UnmarshalJSON(data []byte) error {
272+
// if m == nil {
273+
// return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
274+
// }
275+
// *m = append((*m)[0:0], data...)
276+
// return nil
277+
// }
278+
//
279+
// var _ Marshaler = (*RawMessage)(nil)
280+
// var _ Unmarshaler = (*RawMessage)(nil)
281+
//
282+
// EDIT(end)
280283

281284
// A Token holds a value of one of these types:
282285
//

internal/encoding/json/time.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func timeMarshalEncoder(e *encodeState, v reflect.Value, opts encOpts) bool {
5050
if b != nil {
5151
e.Grow(len(b))
5252
out := e.AvailableBuffer()
53-
out, _ = appendCompact(out, b, opts.escapeHTML)
53+
out, _ = appendCompact(out, b, opts)
5454
e.Buffer.Write(out)
5555
return true
5656
}

packages/param/encoder.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func MarshalWithExtras[T ParamStruct, R any](f T, underlying any, extras map[str
6666
} else if ovr, ok := f.Overrides(); ok {
6767
return shimjson.Marshal(ovr)
6868
} else {
69-
return shimjson.Marshal(underlying)
69+
return shimjson.Marshal(underlying, shimjson.WithSkipCompaction(true))
7070
}
7171
}
7272

@@ -96,7 +96,7 @@ func MarshalUnion[T ParamStruct](metadata T, variants ...any) ([]byte, error) {
9696
Err: fmt.Errorf("expected union to have only one present variant, got %d", nPresent),
9797
}
9898
}
99-
return shimjson.Marshal(variants[presentIdx])
99+
return shimjson.Marshal(variants[presentIdx], shimjson.WithSkipCompaction(true))
100100
}
101101

102102
// typeFor is shimmed from Go 1.23 "reflect" package

packages/param/encoder_test.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package param_test
22

33
import (
4+
"bytes"
45
"encoding/json"
6+
"reflect"
57
"testing"
68
"time"
79

10+
shimjson "github.com/browserbase/stagehand-go/v3/internal/encoding/json"
811
"github.com/browserbase/stagehand-go/v3/packages/param"
912
)
1013

@@ -375,3 +378,176 @@ func TestNullStructUnion(t *testing.T) {
375378
t.Fatalf("expected null, received %s", string(b))
376379
}
377380
}
381+
382+
//
383+
// Compaction optimization
384+
//
385+
386+
type NonCompactedDoubleParent struct {
387+
Prop string `json:"prop"`
388+
Parent NonCompactedParent `json:"parent"`
389+
390+
param.APIObject
391+
}
392+
393+
type NonCompactedParent struct {
394+
BadChild NonCompacted `json:"bad_child"`
395+
396+
param.APIObject
397+
}
398+
399+
type NonCompacted struct {
400+
Raw string
401+
402+
param.APIObject
403+
}
404+
405+
func (a NonCompactedDoubleParent) MarshalJSON() ([]byte, error) {
406+
type shadow NonCompactedDoubleParent
407+
return param.MarshalObject(a, (*shadow)(&a))
408+
}
409+
410+
func (a NonCompactedParent) MarshalJSON() ([]byte, error) {
411+
type shadow NonCompactedParent
412+
return param.MarshalObject(a, (*shadow)(&a))
413+
}
414+
415+
func (a NonCompacted) MarshalJSON() ([]byte, error) {
416+
if a.Raw == "" {
417+
a.Raw = nonCompactedRaw
418+
}
419+
return []byte(a.Raw), nil
420+
}
421+
422+
var nonCompactedRaw string = ` { "foo": "bar" } `
423+
424+
func TestAppendCompactBroken(t *testing.T) {
425+
tests := map[string]struct {
426+
value json.Marshaler
427+
}{
428+
"red/illegal-json": {
429+
NonCompacted{Raw: `{ "broken": "json" `},
430+
},
431+
"red/nested-with-illegal-json": {
432+
NonCompactedParent{BadChild: NonCompacted{
433+
Raw: `{ "broken": "json" `,
434+
}},
435+
},
436+
}
437+
438+
for name, test := range tests {
439+
t.Run(name, func(t *testing.T) {
440+
v, err := json.Marshal(test.value)
441+
if err == nil {
442+
t.Fatal("expected error got", v)
443+
}
444+
})
445+
}
446+
}
447+
448+
// TestAppendCompact validates an optimization for internal SDK types to
449+
// avoid O(keys^2) iteration over each JSON object.
450+
//
451+
// It's possible to intentionally trigger this behavior as both a user and
452+
// SDK developer. However, the edge case is quite pathological and requires
453+
// calling [json.Marshaler.MarshalJSON] rather than [json.Marshal].
454+
func TestAppendCompact(t *testing.T) {
455+
456+
tests := map[string]struct {
457+
value json.Marshaler
458+
expected string
459+
}{
460+
//
461+
// Non-compacted cases
462+
//
463+
// Note this is how to exploit the compacter to fail, you must call [json.Marshaler.MarshalJSON] rather than [json.Marshal].
464+
// The type must also embed [param.APIObject] and return non-compacted JSON.
465+
//
466+
467+
"no-compact/fails-compaction": {
468+
NonCompacted{Raw: nonCompactedRaw},
469+
nonCompactedRaw,
470+
},
471+
"no-compact/nested-with-bad-child": {
472+
NonCompactedParent{BadChild: NonCompacted{
473+
Raw: nonCompactedRaw,
474+
}},
475+
`{"bad_child":` + nonCompactedRaw + `}`,
476+
},
477+
"no-compact/double-nested-with-bad-child": {
478+
NonCompactedDoubleParent{Prop: "1", Parent: NonCompactedParent{BadChild: NonCompacted{
479+
Raw: nonCompactedRaw,
480+
}}},
481+
`{"prop":"1","parent":{"bad_child":` + nonCompactedRaw + `}}`,
482+
},
483+
484+
//
485+
// Compacted cases
486+
//
487+
488+
"override/spaces-within": {
489+
param.Override[NonCompactedDoubleParent](json.RawMessage(`{"com": "pact"}`)),
490+
`{"com":"pact"}`,
491+
},
492+
"override/spaces-after": {
493+
param.Override[NonCompactedDoubleParent](json.RawMessage(`{"com":"pact"} `)),
494+
`{"com":"pact"}`,
495+
},
496+
"override/spaces-before": {
497+
param.Override[NonCompactedDoubleParent](json.RawMessage(` {"com":"pact"}`)),
498+
`{"com":"pact"}`,
499+
},
500+
"override/spaces-around": {
501+
param.Override[NonCompactedDoubleParent](json.RawMessage(` { "com": "pact" }`)),
502+
`{"com":"pact"}`,
503+
},
504+
"override/override-with-nested": {
505+
param.Override[NonCompactedDoubleParent](NonCompactedParent{}),
506+
`{"bad_child":{"foo":"bar"}}`,
507+
},
508+
"override/override-with-non-compacted": {
509+
param.Override[NonCompactedDoubleParent](NonCompacted{}),
510+
`{"foo":"bar"}`,
511+
},
512+
}
513+
514+
for name, test := range tests {
515+
t.Run(name+"/marshal-json", func(t *testing.T) {
516+
b, err := test.value.MarshalJSON()
517+
if err != nil {
518+
t.Fatalf("didn't expect error %v, expected %s", err, test.expected)
519+
}
520+
if string(b) != test.expected {
521+
t.Fatalf("expected %s (%s), received %s", test.expected, reflect.TypeOf(test.value), string(b))
522+
}
523+
})
524+
525+
t.Run(name+"/json-marshal", func(t *testing.T) {
526+
b, err := json.Marshal(test.value)
527+
if err != nil {
528+
t.Fatalf("didn't expect error %v, expected %s", err, test.expected)
529+
}
530+
531+
// expected output of JSON Marshal should always be compacted
532+
var compactedExpected bytes.Buffer
533+
err = json.Compact(&compactedExpected, []byte(test.expected))
534+
if err != nil {
535+
t.Fatalf("didn't expect error %v, expected %s", err, test.expected)
536+
}
537+
538+
if string(b) != compactedExpected.String() {
539+
t.Fatalf("expected %s (%s), received %s", test.expected, reflect.TypeOf(test.value), string(b))
540+
}
541+
})
542+
543+
t.Run(name+"/shimjson-marshal", func(t *testing.T) {
544+
b, err := shimjson.Marshal(test.value)
545+
if err != nil {
546+
t.Fatalf("didn't expect error %v, expected %s", err, test.expected)
547+
}
548+
if string(b) != test.expected {
549+
t.Logf("expected %s (%s), received %s", test.expected, reflect.TypeOf(test.value), string(b))
550+
}
551+
})
552+
}
553+
}

0 commit comments

Comments
 (0)