Skip to content

Commit 9cfff0c

Browse files
authored
Propagate errors in expr evaluation (#489)
Previously it stopped error values at param boundaries, but it would still crash on an outer expr from type mismatch. This adds a priority error propagation to prevent compiler crashes.
1 parent cbe46eb commit 9cfff0c

4 files changed

Lines changed: 84 additions & 9 deletions

File tree

compiler/src/main/scala/edg/compiler/ExprEvaluate.scala

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ object ExprEvaluate {
1515

1616
def evalBinary(binary: expr.BinaryExpr, lhs: ExprValue, rhs: ExprValue): ExprValue = {
1717
import expr.BinaryExpr.Op
18+
19+
ErrorValue.aggregate(Seq(lhs, rhs)) match {
20+
case Some(errorValue) => return errorValue // errors propagation takes priority
21+
case None => () // continue with normal evaluation
22+
}
23+
1824
binary.op match {
1925
// Note promotion rules: range takes precedence, then float, then int
2026
case Op.ADD => (lhs, rhs) match {
@@ -223,6 +229,12 @@ object ExprEvaluate {
223229

224230
def evalBinarySet(binarySet: expr.BinarySetExpr, lhs: ExprValue, rhs: ExprValue): ExprValue = {
225231
import expr.BinarySetExpr.Op
232+
233+
ErrorValue.aggregate(Seq(lhs, rhs)) match {
234+
case Some(errorValue) => return errorValue // errors propagation takes priority
235+
case None => () // continue with normal evaluation
236+
}
237+
226238
binarySet.op match {
227239
// Note promotion rules: range takes precedence, then float, then int
228240
// TODO: can we deduplicate these cases to delegate them to evalBinary?
@@ -272,6 +284,12 @@ object ExprEvaluate {
272284

273285
def evalUnary(unary: expr.UnaryExpr, `val`: ExprValue): ExprValue = {
274286
import expr.UnaryExpr.Op
287+
288+
ErrorValue.aggregate(Seq(`val`)) match {
289+
case Some(errorValue) => return errorValue // errors propagation takes priority
290+
case None => () // continue with normal evaluation
291+
}
292+
275293
(unary.op, `val`) match {
276294
case (Op.NEGATE, `val`) => `val` match {
277295
case RangeValue(valMin, valMax) =>
@@ -307,6 +325,13 @@ object ExprEvaluate {
307325

308326
def evalUnarySet(unarySet: expr.UnarySetExpr, vals: ExprValue, emptyValue: ExprValue): ExprValue = {
309327
import expr.UnarySetExpr.Op
328+
329+
// note this does not short circuit out emptyValue if vals is not empty
330+
ErrorValue.aggregate(Seq(vals, emptyValue)) match {
331+
case Some(errorValue) => return errorValue // errors propagation takes priority
332+
case None => () // continue with normal evaluation
333+
}
334+
310335
(unarySet.op, vals) match {
311336
case (_, ArrayValue.Empty(_)) => emptyValue
312337
// In this case we don't do numeric promotion
@@ -402,29 +427,42 @@ object ExprEvaluate {
402427

403428
def evalStruct(struct: expr.StructExpr, vals: Map[String, ExprValue]): ExprValue = ???
404429

405-
def evalRange(range: expr.RangeExpr, minimum: ExprValue, maximum: ExprValue): ExprValue = (minimum, maximum) match {
406-
case (FloatPromotable(lhs), FloatPromotable(rhs)) => if (lhs <= rhs) {
407-
RangeValue(lhs, rhs)
408-
} else {
409-
ErrorValue(Some(s"range($minimum, $maximum) is malformed, $minimum > $maximum"))
410-
}
411-
case _ => throw new ExprEvaluateException(s"Unknown range operands types $minimum $maximum from $range")
430+
def evalRange(range: expr.RangeExpr, minimum: ExprValue, maximum: ExprValue): ExprValue = {
431+
ErrorValue.aggregate(Seq(minimum, maximum)) match {
432+
case Some(errorValue) => return errorValue // errors propagation takes priority
433+
case None => () // continue with normal evaluation
434+
}
435+
436+
(minimum, maximum) match {
437+
case (FloatPromotable(lhs), FloatPromotable(rhs)) => if (lhs <= rhs) {
438+
RangeValue(lhs, rhs)
439+
} else {
440+
ErrorValue(Some(s"range($minimum, $maximum) is malformed, $minimum > $maximum"))
441+
}
442+
case _ => throw new ExprEvaluateException(s"Unknown range operands types $minimum $maximum from $range")
443+
}
412444
}
413445

414446
def evalIfThenElse(ite: expr.IfThenElseExpr, cond: ExprValue, tru: ExprValue, fal: ExprValue): ExprValue =
415447
cond match {
448+
case ErrorValue(_) => cond
416449
case BooleanValue(true) => tru
417450
case BooleanValue(false) => fal
418451
case _ => throw new ExprEvaluateException(s"Unknown condition types if $cond then $tru else $fal from $ite")
419452
}
420453

421-
def evalExtract(extract: expr.ExtractExpr, container: ExprValue, index: ExprValue): ExprValue =
454+
def evalExtract(extract: expr.ExtractExpr, container: ExprValue, index: ExprValue): ExprValue = {
455+
ErrorValue.aggregate(Seq(container, index)) match {
456+
case Some(errorValue) => return errorValue // errors propagation takes priority
457+
case None => () // continue with normal evaluation
458+
}
422459
(container, index) match {
423460
case (ArrayValue(container), IntValue(index)) => container(index.toInt)
424461
case _ => throw new ExprEvaluateException(
425462
s"Unknown operand types for extract element $index from $container from $extract"
426463
)
427464
}
465+
}
428466
}
429467

430468
class ExprEvaluate(refs: ConstProp, root: DesignPath) extends ValueExprMap[ExprValue] {

compiler/src/main/scala/edg/compiler/ExprValue.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,26 @@ object ExprValue {
5151
}
5252
}
5353

54+
object ErrorValue {
55+
def aggregate(exprs: Seq[ExprValue]): Option[ErrorValue] = {
56+
val errors = exprs.collect { case error: ErrorValue => error }
57+
errors match {
58+
case Seq() => None
59+
case Seq(single) => Some(single)
60+
case multiple => Some(aggregateErrors(multiple))
61+
}
62+
}
63+
64+
private def aggregateErrors(errors: Seq[ErrorValue]): ErrorValue = {
65+
val msgs = errors.flatMap(_.msg)
66+
if (msgs.isEmpty) {
67+
ErrorValue(None)
68+
} else {
69+
ErrorValue(Some(msgs.mkString("; ")))
70+
}
71+
}
72+
}
73+
5474
case class ErrorValue(msg: Option[String]) extends ExprValue {
5575
def toLit: lit.ValueLit = Literal.Error(msg)
5676
def toStringValue: String = s"error(${msg.getOrElse("")})"

compiler/src/test/scala/edg/compiler/ConstPropAssignTest.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class ConstPropAssignTest extends AnyFlatSpec {
188188
assert(constProp.getAllSolved(IndirectDesignPath() + "a").isInstanceOf[ErrorValue])
189189
}
190190

191-
it should "not propagate generated ErrorValues" in {
191+
it should "not propagate ErrorValue-valued parameters" in {
192192
import edgir.expr.expr.BinaryExpr.Op
193193
val constProp = new ConstProp()
194194
constProp.addDeclaration(DesignPath() + "x", ValInit.Range)
@@ -208,4 +208,21 @@ class ConstPropAssignTest extends AnyFlatSpec {
208208
assert(constProp.getValue(IndirectDesignPath() + "a").get.isInstanceOf[ErrorValue])
209209
constProp.getValue(IndirectDesignPath() + "b") shouldBe None
210210
}
211+
212+
it should "propagate ErrorValues" in {
213+
import edgir.expr.expr.BinaryExpr.Op
214+
val constProp = new ConstProp()
215+
constProp.addDeclaration(DesignPath() + "a", ValInit.Range)
216+
constProp.addAssignExpr(
217+
IndirectDesignPath() + "a",
218+
ValueExpr.BinOp(
219+
Op.ADD,
220+
ValueExpr.BinOp(Op.SHRINK_MULT, ValueExpr.Literal(1, 1), ValueExpr.Literal(0, 2)),
221+
ValueExpr.Literal(0, 0)
222+
),
223+
)
224+
225+
constProp.getErrors should not be empty
226+
assert(constProp.getValue(IndirectDesignPath() + "a").get.isInstanceOf[ErrorValue])
227+
}
211228
}
2.66 KB
Binary file not shown.

0 commit comments

Comments
 (0)