@@ -75,22 +75,40 @@ func Is(err, reference error) bool {
7575 return false
7676 }
7777
78- // Not directly equal. Try harder, using error marks. We don't do
79- // this during the loop above as it may be more expensive.
80- //
81- // Note: there is a more effective recursive algorithm that ensures
82- // that any pair of string only gets compared once. Should the
83- // following code become a performance bottleneck, that algorithm
84- // can be considered instead.
85- refMark := getMark (reference )
86- for c := err ; c != nil ; c = errbase .UnwrapOnce (c ) {
87- if equalMarks (getMark (c ), refMark ) {
78+ return checkMark (err , reference )
79+ }
80+
81+ func checkMark (err , reference error ) bool {
82+ for errNext := err ; errNext != nil ; errNext = errbase .UnwrapOnce (errNext ) {
83+ if isMarkEqual (errNext , reference ) {
8884 return true
8985 }
9086 }
9187 return false
9288}
9389
90+ func isMarkEqual (err , reference error ) bool {
91+ _ , errIsMark := err .(* withMark )
92+ _ , refIsMark := reference .(* withMark )
93+ if errIsMark || refIsMark {
94+ // If either error is a mark, use the more general
95+ // equalMarks() function.
96+ return equalMarks (getMark (err ), getMark (reference ))
97+ }
98+
99+ m1 := err
100+ m2 := reference
101+ for m1 != nil {
102+ if ! errbase .EqualTypeMark (m1 , m2 ) {
103+ return false
104+ }
105+ m1 = errbase .UnwrapOnce (m1 )
106+ m2 = errbase .UnwrapOnce (m2 )
107+ }
108+
109+ return safeGetErrMsg (err ) == safeGetErrMsg (reference )
110+ }
111+
94112func tryDelegateToIsMethod (err , reference error ) bool {
95113 if x , ok := err .(interface { Is (error ) bool }); ok && x .Is (reference ) {
96114 return true
@@ -222,6 +240,8 @@ func equalMarks(m1, m2 errorMark) bool {
222240 return false
223241 }
224242 for i , t := range m1 .types {
243+ // TODO(jeffswenson): I think there is a bug here. What if the chains
244+ // are of different lengths?
225245 if ! t .Equals (m2 .types [i ]) {
226246 return false
227247 }
@@ -234,7 +254,10 @@ func getMark(err error) errorMark {
234254 if m , ok := err .(* withMark ); ok {
235255 return m .mark
236256 }
237- m := errorMark {msg : safeGetErrMsg (err ), types : []errorspb.ErrorTypeMark {errbase .GetTypeMark (err )}}
257+ m := errorMark {
258+ msg : safeGetErrMsg (err ),
259+ types : []errorspb.ErrorTypeMark {errbase .GetTypeMark (err )},
260+ }
238261 for c := errbase .UnwrapOnce (err ); c != nil ; c = errbase .UnwrapOnce (c ) {
239262 m .types = append (m .types , errbase .GetTypeMark (c ))
240263 }
0 commit comments