Skip to content

Commit 3ff6ee5

Browse files
committed
errbase: optimize errors.Is
1 parent 5a1487b commit 3ff6ee5

2 files changed

Lines changed: 61 additions & 14 deletions

File tree

errbase/encode.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,32 @@ func GetTypeMark(err error) errorspb.ErrorTypeMark {
305305
return errorspb.ErrorTypeMark{FamilyName: familyName, Extension: extension}
306306
}
307307

308+
// EqualTypeMark checks whether `GetTypeMark(e1).Equals(GetTypeMark(e2))`. It
309+
// is written to be be optimized for the case where neither error has
310+
// serialized type information.
311+
func EqualTypeMark(e1, e2 error) bool {
312+
slowPath := func(err error) bool {
313+
switch err.(type) {
314+
case *opaqueLeaf:
315+
return true
316+
case *opaqueLeafCauses:
317+
return true
318+
case *opaqueWrapper:
319+
return true
320+
case TypeKeyMarker:
321+
return true
322+
}
323+
return false
324+
}
325+
if slowPath(e1) || slowPath(e2) {
326+
return GetTypeMark(e1).Equals(GetTypeMark(e2))
327+
}
328+
329+
t1 := reflect.TypeOf(e1)
330+
t2 := reflect.TypeOf(e2)
331+
return t1.PkgPath() == t2.PkgPath() && t1.String() == t2.String()
332+
}
333+
308334
// RegisterLeafEncoder can be used to register new leaf error types to
309335
// the library. Registered types will be encoded using their own
310336
// Go type when an error is encoded. Wrappers that have not been
@@ -385,9 +411,7 @@ func RegisterWrapperEncoder(theType TypeKey, encoder WrapperEncoder) {
385411
// Note: if the error type has been migrated from a previous location
386412
// or a different type, ensure that RegisterTypeMigration() was called
387413
// prior to RegisterWrapperEncoder().
388-
func RegisterWrapperEncoderWithMessageType(
389-
theType TypeKey, encoder WrapperEncoderWithMessageType,
390-
) {
414+
func RegisterWrapperEncoderWithMessageType(theType TypeKey, encoder WrapperEncoderWithMessageType) {
391415
if encoder == nil {
392416
delete(encoders, theType)
393417
} else {

markers/markers.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,22 +75,40 @@ func Is(err, reference error) bool {
7575
return false
7676
}
7777

78-
// Not directly equal. Try harder, using error marks. We don't do
79-
// this during the loop above as it may be more expensive.
80-
//
81-
// Note: there is a more effective recursive algorithm that ensures
82-
// that any pair of string only gets compared once. Should the
83-
// following code become a performance bottleneck, that algorithm
84-
// can be considered instead.
85-
refMark := getMark(reference)
86-
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
87-
if equalMarks(getMark(c), refMark) {
78+
return checkMark(err, reference)
79+
}
80+
81+
func checkMark(err, reference error) bool {
82+
for errNext := err; errNext != nil; errNext = errbase.UnwrapOnce(errNext) {
83+
if isMarkEqual(errNext, reference) {
8884
return true
8985
}
9086
}
9187
return false
9288
}
9389

90+
func isMarkEqual(err, reference error) bool {
91+
_, errIsMark := err.(*withMark)
92+
_, refIsMark := reference.(*withMark)
93+
if errIsMark || refIsMark {
94+
// If either error is a mark, use the more general
95+
// equalMarks() function.
96+
return equalMarks(getMark(err), getMark(reference))
97+
}
98+
99+
m1 := err
100+
m2 := reference
101+
for m1 != nil {
102+
if !errbase.EqualTypeMark(m1, m2) {
103+
return false
104+
}
105+
m1 = errbase.UnwrapOnce(m1)
106+
m2 = errbase.UnwrapOnce(m2)
107+
}
108+
109+
return safeGetErrMsg(err) == safeGetErrMsg(reference)
110+
}
111+
94112
func tryDelegateToIsMethod(err, reference error) bool {
95113
if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(reference) {
96114
return true
@@ -222,6 +240,8 @@ func equalMarks(m1, m2 errorMark) bool {
222240
return false
223241
}
224242
for i, t := range m1.types {
243+
// TODO(jeffswenson): I think there is a bug here. What if the chains
244+
// are of different lengths?
225245
if !t.Equals(m2.types[i]) {
226246
return false
227247
}
@@ -234,7 +254,10 @@ func getMark(err error) errorMark {
234254
if m, ok := err.(*withMark); ok {
235255
return m.mark
236256
}
237-
m := errorMark{msg: safeGetErrMsg(err), types: []errorspb.ErrorTypeMark{errbase.GetTypeMark(err)}}
257+
m := errorMark{
258+
msg: safeGetErrMsg(err),
259+
types: []errorspb.ErrorTypeMark{errbase.GetTypeMark(err)},
260+
}
238261
for c := errbase.UnwrapOnce(err); c != nil; c = errbase.UnwrapOnce(c) {
239262
m.types = append(m.types, errbase.GetTypeMark(c))
240263
}

0 commit comments

Comments
 (0)