Skip to content

Commit e7f7deb

Browse files
committed
rework fusion framework based on immutable named types
This commit reworks downcast, upcast, and fuser to presume that named types are immutable. The changes for immutability are forthcoming but the fusion code now assumes this invariant. This new deisng involved fused types to now carry the named types instead of stripping them and deferring them to the fusion subtypes. This is key to allow the fusion runtime to manipulate named types. When support for recursive types is added, the recursive type will remain in the fusion and the concrete types will be fused below. With this approach, a fused recursive type will rarely need to be unfused. This commit also tightens up the algorithms for upcast/downcast so they should now be generally more reliable and can serve as a model for implementing their vam counterparts.
1 parent 75dc693 commit e7f7deb

13 files changed

Lines changed: 296 additions & 110 deletions

File tree

runtime/sam/expr/agg/fuser.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ func (f *Fuser) fuse(a, b super.Type) super.Type {
121121
}
122122
case *super.TypeNamed:
123123
if b, ok := b.(*super.TypeNamed); ok && a.Name == b.Name {
124+
if a.Type != b.Type {
125+
// The fusion algorithm does not handle named types that change.
126+
// We will soon maked such types immutable, but for now we just
127+
// return type error({}) to avoid any tests that might do this.
128+
recType := f.sctx.MustLookupTypeRecord([]super.Field{
129+
super.NewField(a.Name, a.Type),
130+
})
131+
return f.sctx.LookupTypeError(recType)
132+
}
124133
named, err := f.sctx.LookupTypeNamed(a.Name, f.fuse(a.Type, b.Type))
125134
if err != nil {
126135
panic(err)

runtime/sam/expr/function/defuse.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type defuse struct {
1616
func NewDefuse(sctx *super.Context) *defuse {
1717
return &defuse{
1818
sctx: sctx,
19-
downcast: &downcast{sctx: sctx},
19+
downcast: &downcast{sctx: sctx, name: "defuse"},
2020
has: make(map[super.Type]bool),
2121
}
2222
}
@@ -95,11 +95,11 @@ func (d *defuse) eval(in super.Value) super.Value {
9595
case *super.TypeUnion:
9696
return d.eval(in.DeunionIntoNameds())
9797
case *super.TypeFusion:
98-
_, subType := typ.Deref(d.sctx, in.Bytes())
99-
if out, ok := d.downcast.Cast(in, subType); ok {
100-
return out
98+
out, errVal := d.downcast.defuse(typ, in.Bytes())
99+
if errVal != nil {
100+
return *errVal
101101
}
102-
return d.sctx.WrapError("cannot defuse super value", in)
102+
return out
103103
default:
104104
// primitives, named types, enums
105105
// BTW, named types are a barrier to defuse.
Lines changed: 148 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
package function
22

33
import (
4+
"slices"
5+
46
"github.com/brimdata/super"
57
"github.com/brimdata/super/scode"
68
"github.com/brimdata/super/sup"
79
)
810

911
type downcast struct {
1012
sctx *super.Context
13+
name string
1114
}
1215

13-
func NewDowncast(sctx *super.Context) Caster {
14-
return &downcast{sctx}
16+
func NewDowncast(sctx *super.Context, name string) Caster {
17+
return &downcast{sctx, name}
1518
}
1619

1720
func (d *downcast) Call(args []super.Value) super.Value {
@@ -23,164 +26,238 @@ func (d *downcast) Call(args []super.Value) super.Value {
2326
if err != nil {
2427
panic(err)
2528
}
26-
val, ok := d.Cast(from, typ)
27-
if !ok {
28-
return d.sctx.WrapError("downcast: value not a supertype of "+sup.FormatType(typ), from)
29+
val, errVal := d.downcast(from.Type(), from.Bytes(), typ)
30+
if errVal != nil {
31+
return *errVal
2932
}
3033
return val
3134
}
3235

3336
func (d *downcast) Cast(from super.Value, to super.Type) (super.Value, bool) {
34-
var b scode.Builder
35-
if ok := d.downcast(&b, from.Type(), from.Bytes(), to); ok {
36-
return super.NewValue(to, b.Bytes().Body()), true
37-
}
38-
return super.Value{}, false
37+
val, errVal := d.downcast(from.Type(), from.Bytes(), to)
38+
return val, errVal == nil
3939
}
4040

41-
func (d *downcast) downcast(b *scode.Builder, typ super.Type, bytes scode.Bytes, to super.Type) bool {
42-
typ, bytes = deunion(typ, bytes)
43-
if superType, ok := typ.(*super.TypeFusion); ok {
44-
superBytes, _ := superType.Deref(d.sctx, bytes)
45-
return d.downcast(b, superType.Type, superBytes, to)
41+
func (d *downcast) downcast(typ super.Type, bytes scode.Bytes, to super.Type) (super.Value, *super.Value) {
42+
if _, ok := to.(*super.TypeUnion); !ok {
43+
if fusionType, ok := typ.(*super.TypeFusion); ok {
44+
superBytes, subtype := fusionType.Deref(d.sctx, bytes)
45+
return d.downcast(fusionType.Type, superBytes, subtype)
46+
}
4647
}
47-
typ = super.TypeUnder(typ)
48+
typ, bytes = deunion(typ, bytes)
4849
switch to := to.(type) {
4950
case *super.TypeRecord:
50-
return d.toRecord(b, typ, bytes, to)
51+
return d.toRecord(typ, bytes, to)
5152
case *super.TypeArray:
52-
return d.toArray(b, typ, bytes, to)
53+
return d.toArray(typ, bytes, to)
5354
case *super.TypeSet:
54-
return d.toSet(b, typ, bytes, to)
55+
return d.toSet(typ, bytes, to)
5556
case *super.TypeMap:
56-
return d.toMap(b, typ, bytes, to)
57+
return d.toMap(typ, bytes, to)
5758
case *super.TypeUnion:
58-
return d.toUnion(b, typ, bytes, to)
59+
return d.toUnion(typ, bytes, to)
5960
case *super.TypeError:
60-
return d.toError(b, typ, bytes, to)
61+
return d.toError(typ, bytes, to)
6162
case *super.TypeNamed:
62-
return d.downcast(b, typ, bytes, to.Type)
63+
return d.toNamed(typ, bytes, to)
6364
case *super.TypeFusion:
6465
// Can't downcast to a super type
65-
return false
66+
return super.Value{}, d.sctx.WrapError("downcast: cannot downcast to a fusion type", super.NewValue(typ, bytes)).Ptr()
6667
default:
6768
if typ == to {
68-
b.Append(bytes)
69-
return true
69+
return super.NewValue(typ, bytes), nil
70+
} else {
71+
typ, bytes := deunion(typ, bytes)
72+
if typ == to {
73+
return super.NewValue(typ, bytes), nil
74+
}
7075
}
71-
return false
76+
return super.Value{}, d.errMismatch(typ, bytes, to)
7277
}
7378
}
7479

75-
func (d *downcast) toRecord(b *scode.Builder, typ super.Type, bytes scode.Bytes, to *super.TypeRecord) bool {
80+
func (d *downcast) defuse(fusionType *super.TypeFusion, bytes scode.Bytes) (super.Value, *super.Value) {
81+
superBytes, subtype := fusionType.Deref(d.sctx, bytes)
82+
return d.downcast(fusionType.Type, superBytes, subtype)
83+
}
84+
85+
func (d *downcast) toRecord(typ super.Type, bytes scode.Bytes, to *super.TypeRecord) (super.Value, *super.Value) {
7686
fromType, ok := typ.(*super.TypeRecord)
7787
if !ok {
78-
return false
88+
return super.Value{}, d.errMismatch(typ, bytes, to)
7989
}
8090
var nones []int
8191
var optOff int
92+
b := scode.NewBuilder()
8293
b.BeginContainer()
83-
for _, toField := range to.Fields { // ranging through to fields and lookup up from
94+
for k, toField := range to.Fields { // ranging through to fields and lookup up from
8495
elemType, elemBytes, none, ok := derefWithNoneAndOk(fromType, bytes, toField.Name)
8596
if !ok {
8697
// The super value must have all the fields of the subtype cast.
8798
// It's missing a field, so fail.
88-
return false
99+
return super.Value{}, d.errSubtype(typ, bytes, to)
89100
}
90101
if none {
91102
if !toField.Opt {
92103
// A none can't go in a non-optional field.
93-
return false
104+
return super.Value{}, d.errSubtype(typ, bytes, to)
94105
}
95106
nones = append(nones, optOff)
96107
optOff++
108+
} else if toField.Opt && !fromType.Fields[k].Opt {
109+
return super.Value{}, d.errSubtype(typ, bytes, to)
97110
} else {
98111
// We have the value and the to field. Downcast recursively.
99-
if ok := d.downcast(b, elemType, elemBytes, toField.Type); !ok {
100-
return false
112+
val, errVal := d.downcast(elemType, elemBytes, toField.Type)
113+
if errVal != nil {
114+
return super.Value{}, errVal
101115
}
102116
if toField.Opt {
103117
optOff++
104118
}
119+
b.Append(val.Bytes())
105120
}
106121
}
107122
b.EndContainerWithNones(to.Opts, nones)
108-
return true
123+
return super.NewValue(to, b.Bytes().Body()), nil
109124
}
110125

111-
func (d *downcast) toArray(b *scode.Builder, typ super.Type, bytes scode.Bytes, to *super.TypeArray) bool {
126+
func (d *downcast) toArray(typ super.Type, bytes scode.Bytes, to *super.TypeArray) (super.Value, *super.Value) {
112127
if arrayType, ok := typ.(*super.TypeArray); ok {
113-
return d.toContainer(b, arrayType.Type, bytes, to.Type)
128+
return d.toContainer(arrayType.Type, bytes, to, to.Type)
114129
}
115-
return false
130+
return super.Value{}, d.errMismatch(typ, bytes, to)
116131
}
117132

118-
func (d *downcast) toSet(b *scode.Builder, typ super.Type, bytes scode.Bytes, to *super.TypeSet) bool {
133+
func (d *downcast) toSet(typ super.Type, bytes scode.Bytes, to *super.TypeSet) (super.Value, *super.Value) {
119134
if setType, ok := typ.(*super.TypeSet); ok {
120135
// XXX normalize set contents? can reach into body here blah
121-
return d.toContainer(b, setType.Type, bytes, to.Type)
136+
return d.toContainer(setType.Type, bytes, to, to.Type)
122137
}
123-
return false
138+
return super.Value{}, d.errMismatch(typ, bytes, to)
124139
}
125140

126-
func (d *downcast) toContainer(b *scode.Builder, typ super.Type, bytes scode.Bytes, to super.Type) bool {
141+
func (d *downcast) toContainer(elemType super.Type, bytes scode.Bytes, to super.Type, toElem super.Type) (super.Value, *super.Value) {
142+
b := scode.NewBuilder()
127143
b.BeginContainer()
128144
for it := bytes.Iter(); !it.Done(); {
129-
if ok := d.downcast(b, typ, it.Next(), to); !ok {
130-
return false
145+
val, errVal := d.downcast(elemType, it.Next(), toElem)
146+
if errVal != nil {
147+
return super.Value{}, errVal
131148
}
149+
b.Append(val.Bytes())
132150
}
133151
b.EndContainer()
134-
return true
152+
return super.NewValue(to, b.Bytes().Body()), nil
135153
}
136154

137-
func (d *downcast) toMap(b *scode.Builder, typ super.Type, bytes scode.Bytes, to *super.TypeMap) bool {
155+
func (d *downcast) toMap(typ super.Type, bytes scode.Bytes, to *super.TypeMap) (super.Value, *super.Value) {
138156
mapType, ok := typ.(*super.TypeMap)
139157
if !ok {
140-
return false
158+
return super.Value{}, d.errMismatch(typ, bytes, to)
141159
}
160+
b := scode.NewBuilder()
142161
b.BeginContainer()
143162
for it := bytes.Iter(); !it.Done(); {
144-
if ok := d.downcast(b, mapType.KeyType, it.Next(), to.KeyType); !ok {
145-
return false
163+
key, errVal := d.downcast(mapType.KeyType, it.Next(), to.KeyType)
164+
if errVal != nil {
165+
return super.Value{}, errVal
146166
}
147-
if ok := d.downcast(b, mapType.ValType, it.Next(), to.ValType); !ok {
148-
return false
167+
b.Append(key.Bytes())
168+
val, errVal := d.downcast(mapType.ValType, it.Next(), to.ValType)
169+
if errVal != nil {
170+
return super.Value{}, errVal
149171
}
172+
b.Append(val.Bytes())
150173
}
151174
b.EndContainer()
152-
return true
175+
return super.NewValue(to, b.Bytes().Body()), nil
153176
}
154177

155-
func (d *downcast) toUnion(b *scode.Builder, typ super.Type, bytes scode.Bytes, to *super.TypeUnion) bool {
156-
tag := d.subTypeOf(typ, bytes, to.Types)
178+
func (d *downcast) toUnion(typ super.Type, bytes scode.Bytes, to *super.TypeUnion) (super.Value, *super.Value) {
179+
if typ == to {
180+
return super.NewValue(typ, bytes), nil
181+
}
182+
tag, typ, bytes := d.subTypeOf(typ, bytes, to.Types)
157183
if tag < 0 {
158-
return false
184+
if _, ok := typ.(*super.TypeUnion); ok {
185+
typ, bytes = deunion(typ, bytes)
186+
return d.downcast(typ, bytes, to)
187+
}
188+
return super.Value{}, d.errSubtype(typ, bytes, to)
159189
}
160-
super.BeginUnion(b, tag)
161-
if ok := d.downcast(b, typ, bytes, to.Types[tag]); !ok {
162-
return false
190+
val, errVal := d.downcast(typ, bytes, to.Types[tag])
191+
if errVal != nil {
192+
return super.Value{}, errVal
163193
}
194+
b := scode.NewBuilder()
195+
super.BeginUnion(b, tag)
196+
b.Append(val.Bytes())
164197
b.EndContainer()
165-
return true
198+
return super.NewValue(to, b.Bytes().Body()), nil
166199
}
167200

168-
func (d *downcast) toError(b *scode.Builder, typ super.Type, bytes scode.Bytes, to *super.TypeError) bool {
201+
// subTypeOf finds the tag in the union array types that this value should be
202+
// downcast to. If the child value is a fusion value, then the type must match
203+
// the subtype of the fusion value. Otherwise, the child wasn't fused, and by
204+
// definition of a fusion type, one of the union types must exactly match the
205+
// child type.
206+
func (d *downcast) subTypeOf(typ super.Type, bytes scode.Bytes, types []super.Type) (int, super.Type, []byte) {
207+
if fusionType, ok := typ.(*super.TypeFusion); ok {
208+
superBytes, subtype := fusionType.Deref(d.sctx, bytes)
209+
return slices.Index(types, subtype), fusionType.Type, superBytes
210+
}
211+
return slices.Index(types, typ), typ, bytes
212+
}
213+
214+
func (d *downcast) toError(typ super.Type, bytes scode.Bytes, to *super.TypeError) (super.Value, *super.Value) {
169215
if errorType, ok := typ.(*super.TypeError); ok {
170-
return d.downcast(b, errorType.Type, bytes, to.Type)
216+
body, errVal := d.downcast(errorType.Type, bytes, to.Type)
217+
if errVal != nil {
218+
return super.Value{}, errVal
219+
}
220+
return super.NewValue(to, body.Bytes()), nil
171221
}
172-
return false
222+
return super.Value{}, d.errMismatch(typ, bytes, to)
173223
}
174224

175-
func (d *downcast) subTypeOf(typ super.Type, bytes scode.Bytes, types []super.Type) int {
176-
// XXX TBD we should make a subtype() function that returns true if a type is
177-
// a subtype of another and use that here and expose it to the language.
178-
var dummy scode.Builder
179-
for k, t := range types {
180-
if ok := d.downcast(&dummy, typ, bytes, t); ok {
181-
return k
225+
func (d *downcast) toNamed(typ super.Type, bytes scode.Bytes, to *super.TypeNamed) (super.Value, *super.Value) {
226+
if unionType, ok := typ.(*super.TypeUnion); ok {
227+
typ, bytes = deunion(typ, bytes)
228+
// If we are casting a union type to a named, we need to look through the
229+
// union for the named type in question since type fusion fuses named
230+
// types by name. Then when we find the name, we need to form the subtype
231+
// from the union options present.
232+
for _, t := range unionType.Types {
233+
if named, ok := t.(*super.TypeNamed); ok && named.Name == to.Name {
234+
typ, bytes = deunion(typ, bytes)
235+
return super.NewValue(to, bytes), nil
236+
}
182237
}
183-
dummy.Reset()
238+
return super.Value{}, d.errMismatch(typ, bytes, to)
184239
}
185-
return -1
240+
if fromType, ok := typ.(*super.TypeNamed); ok {
241+
if fromType.Name != to.Name {
242+
return super.Value{}, d.errMismatch(typ, bytes, to)
243+
}
244+
val, errVal := d.downcast(fromType.Type, bytes, to.Type)
245+
if errVal != nil {
246+
return super.Value{}, errVal
247+
}
248+
return super.NewValue(to, val.Bytes()), errVal
249+
}
250+
val, errVal := d.downcast(typ, bytes, to.Type)
251+
if errVal != nil {
252+
return super.Value{}, errVal
253+
}
254+
return super.NewValue(to, val.Bytes()), errVal
255+
}
256+
257+
func (d *downcast) errMismatch(typ super.Type, bytes []byte, to super.Type) *super.Value {
258+
return d.sctx.WrapError("downcast: type mismatch to "+sup.FormatType(to), super.NewValue(typ, bytes)).Ptr()
259+
}
260+
261+
func (d *downcast) errSubtype(typ super.Type, bytes []byte, to super.Type) *super.Value {
262+
return d.sctx.WrapError("downcast: invalid subtype "+sup.FormatType(to), super.NewValue(typ, bytes)).Ptr()
186263
}

runtime/sam/expr/function/function.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func New(sctx *super.Context, name string, narg int) (expr.Function, error) {
5858
case "downcast":
5959
argmin = 2
6060
argmax = 2
61-
f = &downcast{sctx}
61+
f = &downcast{sctx: sctx, name: "downcast"}
6262
case "error":
6363
f = &Error{sctx: sctx}
6464
case "fields":

runtime/sam/expr/function/fusion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ type fusion struct {
1414
func newFusion(sctx *super.Context) *fusion {
1515
return &fusion{
1616
sctx: sctx,
17-
downcast: NewDowncast(sctx),
17+
downcast: NewDowncast(sctx, "fusion"),
1818
}
1919
}
2020

0 commit comments

Comments
 (0)