|
37 | 37 | import java.util.Arrays; |
38 | 38 | import java.util.IdentityHashMap; |
39 | 39 | import java.util.List; |
| 40 | +import java.util.Optional; |
40 | 41 | import java.util.concurrent.atomic.AtomicReference; |
41 | 42 | import java.util.stream.Collectors; |
42 | 43 | import java.util.stream.Stream; |
|
48 | 49 | import vadl.types.DataType; |
49 | 50 | import vadl.types.MicroArchitectureType; |
50 | 51 | import vadl.types.SIntType; |
51 | | -import vadl.types.StatusType; |
| 52 | +import vadl.types.StructType; |
52 | 53 | import vadl.types.Type; |
53 | 54 | import vadl.types.UIntType; |
54 | 55 | import vadl.utils.BigIntUtils; |
|
65 | 66 | import vadl.viam.Instruction; |
66 | 67 | import vadl.viam.Logic; |
67 | 68 | import vadl.viam.Memory; |
| 69 | +import vadl.viam.Operation; |
68 | 70 | import vadl.viam.Procedure; |
69 | 71 | import vadl.viam.RegisterResource; |
70 | 72 | import vadl.viam.RegisterTensor; |
|
98 | 100 | import vadl.viam.graph.dependency.FuncParamNode; |
99 | 101 | import vadl.viam.graph.dependency.LetNode; |
100 | 102 | import vadl.viam.graph.dependency.MiaBuiltInCall; |
| 103 | +import vadl.viam.graph.dependency.OperationExistsNode; |
| 104 | +import vadl.viam.graph.dependency.OperationForAllNode; |
101 | 105 | import vadl.viam.graph.dependency.ProcCallNode; |
102 | 106 | import vadl.viam.graph.dependency.ReadArtificialResNode; |
103 | 107 | import vadl.viam.graph.dependency.ReadMemNode; |
@@ -1118,6 +1122,36 @@ private ExpressionNode visitIdentifyable(Expr expr) { |
1118 | 1122 | requireNonNull(index.computedTo)); |
1119 | 1123 | } |
1120 | 1124 |
|
| 1125 | + if (computedTarget instanceof ForallThenExpr forallThenExpr) { |
| 1126 | + final var index = forallThenExpr.indices.stream() |
| 1127 | + .filter(idx -> idx.identifier().name.equals(innerName)) |
| 1128 | + .findFirst() |
| 1129 | + .orElseThrow(); |
| 1130 | + final var format = (PseudoFormatType) index.identifier().type(); |
| 1131 | + |
| 1132 | + final List<Operation> ops = format.operations().stream() |
| 1133 | + .map(viamLowering::fetch) |
| 1134 | + .filter(Optional::isPresent).map(Optional::get) |
| 1135 | + .map(Operation.class::cast) |
| 1136 | + .toList(); |
| 1137 | + return new OperationForAllNode.Index(getViamType(format), ops); |
| 1138 | + } |
| 1139 | + |
| 1140 | + if (computedTarget instanceof ExistsInThenExpr existsInThenExpr) { |
| 1141 | + final var index = existsInThenExpr.indices.stream() |
| 1142 | + .filter(idx -> idx.identifier().name.equals(innerName)) |
| 1143 | + .findFirst() |
| 1144 | + .orElseThrow(); |
| 1145 | + final var format = (PseudoFormatType) index.identifier().type(); |
| 1146 | + |
| 1147 | + final List<Operation> ops = format.operations().stream() |
| 1148 | + .map(viamLowering::fetch) |
| 1149 | + .filter(Optional::isPresent).map(Optional::get) |
| 1150 | + .map(Operation.class::cast) |
| 1151 | + .toList(); |
| 1152 | + return new OperationForAllNode.Index(getViamType(format), ops); |
| 1153 | + } |
| 1154 | + |
1121 | 1155 | // Function call without arguments (and no parenthesis) |
1122 | 1156 | if (computedTarget instanceof FunctionDefinition functionDefinition) { |
1123 | 1157 | var function = (Function) viamLowering.fetch(functionDefinition).orElseThrow(); |
@@ -1286,7 +1320,7 @@ private ExpressionNode visitSubCall(CallIndexExpr expr, ExpressionNode exprBefor |
1286 | 1320 | (DataType) getViamType(requireNonNull(subCall.formatFieldType))); |
1287 | 1321 | resultExpr = |
1288 | 1322 | visitSliceIndexCall(slice, subCall.formatFieldType, subCall.argsIndices); |
1289 | | - } else if (exprBeforeSubcall.type() instanceof StatusType) { |
| 1323 | + } else if (exprBeforeSubcall.type() instanceof StructType) { |
1290 | 1324 | var indexing = new StructGetFieldNode(subCall.identifier().name, resultExpr, Type.bool()); |
1291 | 1325 | resultExpr = visitSliceIndexCall(indexing, Type.bool(), subCall.argsIndices); |
1292 | 1326 | } else if (exprBeforeSubcall.type() == MicroArchitectureType.instruction()) { |
@@ -1626,16 +1660,63 @@ public ExpressionNode visit(AsStrExpr expr) { |
1626 | 1660 |
|
1627 | 1661 | @Override |
1628 | 1662 | public ExpressionNode visit(ExistsInExpr expr) { |
1629 | | - throw new RuntimeException( |
1630 | | - "The behavior generator doesn't implement yet: " + expr.getClass().getSimpleName()); |
| 1663 | + |
| 1664 | + final var opDefs = new ArrayList<OperationDefinition>(); |
| 1665 | + final var ops = new ArrayList<Operation>(); |
| 1666 | + |
| 1667 | + for (IsId op : expr.operations) { |
| 1668 | + final var opDef = requireNonNull((OperationDefinition) op.target()); |
| 1669 | + opDefs.add(opDef); |
| 1670 | + viamLowering.fetch(opDef) |
| 1671 | + .ifPresent(o -> ops.add((Operation) o)); |
| 1672 | + } |
| 1673 | + |
| 1674 | + final var idxType = getViamType(PseudoFormatType.of(opDefs)); |
| 1675 | + final var idx = new OperationForAllNode.Index(idxType, ops); |
| 1676 | + |
| 1677 | + var type = getViamType(expr.type()); |
| 1678 | + return new OperationExistsNode(type, idx); |
1631 | 1679 | } |
1632 | 1680 |
|
1633 | 1681 | @Override |
1634 | 1682 | public ExpressionNode visit(ExistsInThenExpr expr) { |
1635 | | - throw new RuntimeException( |
1636 | | - "The behavior generator doesn't implement yet: " + expr.getClass().getSimpleName()); |
| 1683 | + |
| 1684 | + final List<OperationForAllNode.Index> indices = new ArrayList<>(); |
| 1685 | + for (ExistsInThenExpr.Index idx : expr.indices) { |
| 1686 | + final var format = (PseudoFormatType) idx.identifier().type(); |
| 1687 | + final List<Operation> ops = format.operations().stream() |
| 1688 | + .map(viamLowering::fetch) |
| 1689 | + .filter(Optional::isPresent).map(Optional::get) |
| 1690 | + .map(Operation.class::cast) |
| 1691 | + .toList(); |
| 1692 | + final var idxNode = new OperationForAllNode.Index(getViamType(format), ops); |
| 1693 | + indices.add(idxNode); |
| 1694 | + } |
| 1695 | + |
| 1696 | + var body = fetch(expr.thenExpr); |
| 1697 | + var type = getViamType(expr.type()); |
| 1698 | + return new OperationExistsNode(type, indices, body); |
1637 | 1699 | } |
1638 | 1700 |
|
| 1701 | + @Override |
| 1702 | + public ExpressionNode visit(ForallThenExpr expr) { |
| 1703 | + |
| 1704 | + final List<OperationForAllNode.Index> indices = new ArrayList<>(); |
| 1705 | + for (ForallThenExpr.Index idx : expr.indices) { |
| 1706 | + final var format = (PseudoFormatType) idx.identifier().type(); |
| 1707 | + final List<Operation> ops = format.operations().stream() |
| 1708 | + .map(viamLowering::fetch) |
| 1709 | + .filter(Optional::isPresent).map(Optional::get) |
| 1710 | + .map(Operation.class::cast) |
| 1711 | + .toList(); |
| 1712 | + final var idxNode = new OperationForAllNode.Index(getViamType(format), ops); |
| 1713 | + indices.add(idxNode); |
| 1714 | + } |
| 1715 | + |
| 1716 | + var body = fetch(expr.thenExpr); |
| 1717 | + var type = getViamType(expr.type()); |
| 1718 | + return new OperationForAllNode(type, indices, body); |
| 1719 | + } |
1639 | 1720 |
|
1640 | 1721 | @Override |
1641 | 1722 | public ExpressionNode visit(ForallExpr expr) { |
|
0 commit comments