Skip to content

Commit fcb93ab

Browse files
authored
frontend: Group expressions and annotations (#918)
- add forall and exists expressions
1 parent fafdbdd commit fcb93ab

48 files changed

Lines changed: 2058 additions & 64 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

vadl/main/vadl/ast/AnnotationTable.java

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import vadl.viam.Encoding;
5757
import vadl.viam.Endianness;
5858
import vadl.viam.Format;
59+
import vadl.viam.Group;
5960
import vadl.viam.Instruction;
6061
import vadl.viam.Memory;
6162
import vadl.viam.MemoryRegion;
@@ -66,9 +67,11 @@
6667
import vadl.viam.annotations.AsmGenerateRulesAnno;
6768
import vadl.viam.annotations.AsmParserCaseSensitive;
6869
import vadl.viam.annotations.AsmParserCommentString;
70+
import vadl.viam.annotations.AssertAnnotation;
6971
import vadl.viam.annotations.DefineOperandAnnotation;
7072
import vadl.viam.annotations.EnableHtifAnno;
7173
import vadl.viam.annotations.InstructionUndefinedAnno;
74+
import vadl.viam.annotations.StopAnnotation;
7275
import vadl.viam.annotations.TbStateRegisterAnnotation;
7376

7477
/**
@@ -236,7 +239,7 @@ public class AnnotationTable {
236239
.build();
237240

238241
groupOn(MemoryDefinition.class)
239-
.add("big endian", OptExprAnnotation::new)
242+
.add("big endian", OptExprAnnotation::new)
240243
.add("little endian", OptExprAnnotation::new)
241244
.check(ctx -> {
242245
ctx.verifyOnlyOneOfGroup();
@@ -264,6 +267,28 @@ public class AnnotationTable {
264267
.ifPresent(ann -> apply.accept(ann, Endianness.LITTLE));
265268
}).build();
266269

270+
annotationOn(GroupDefinition.class, "assert", () -> new ExprAnnotation(true))
271+
.check((def, annotation, lowering) -> annotation.verifyExprType(Type.bool()))
272+
.applyViam((def, annotation, lowering) -> {
273+
var group = (Group) def;
274+
var graph = new BehaviorLowering(lowering)
275+
.getFunctionGraph(annotation.expr, "Assert " + group.simpleName());
276+
graph.setParentDefinition(group);
277+
group.addAnnotation(new AssertAnnotation(graph));
278+
})
279+
.build();
280+
281+
annotationOn(GroupDefinition.class, "stop", () -> new ExprAnnotation(true))
282+
.check((def, annotation, lowering) -> annotation.verifyExprType(Type.bool()))
283+
.applyViam((def, annotation, lowering) -> {
284+
var group = (Group) def;
285+
var graph = new BehaviorLowering(lowering)
286+
.getFunctionGraph(annotation.expr, "Stop " + group.simpleName());
287+
graph.setParentDefinition(group);
288+
group.addAnnotation(new StopAnnotation(graph));
289+
})
290+
.build();
291+
267292
/// PROCESSOR RELATED ///
268293

269294
annotationOn(ProcessorDefinition.class, "htif", EnableAnnotation::new)
@@ -1003,9 +1028,15 @@ abstract class Annotation implements AnnotationDeclaration, WithLocation {
10031028
@LazyInit
10041029
AnnotationDefinition definition;
10051030

1031+
protected boolean allowMultiple;
1032+
10061033
public Annotation() {
10071034
}
10081035

1036+
public Annotation(boolean allowMultiple) {
1037+
this.allowMultiple = allowMultiple;
1038+
}
1039+
10091040
/**
10101041
* Called by the symbol resolver to resolve parts of the annotation.
10111042
*
@@ -1058,6 +1089,10 @@ protected void verifyValuesNonEmpty(AnnotationDefinition definition) {
10581089
.build();
10591090
}
10601091
}
1092+
1093+
protected boolean allowMultiple() {
1094+
return false;
1095+
}
10611096
}
10621097

10631098
// ---------- GENERAL ANNOTATION CLASSES ----------
@@ -1134,7 +1169,8 @@ class FormatFieldAnnotation extends Annotation {
11341169
Constant.BitSlice slice;
11351170

11361171
@Override
1137-
void resolveName(AnnotationDefinition definition, SymbolTable.SymbolResolver resolver) { }
1172+
void resolveName(AnnotationDefinition definition, SymbolTable.SymbolResolver resolver) {
1173+
}
11381174

11391175
@Override
11401176
void typeCheck(AnnotationDefinition definition, TypeChecker typeChecker) {
@@ -1518,6 +1554,10 @@ public ExprAnnotation() {
15181554
super();
15191555
}
15201556

1557+
public ExprAnnotation(boolean allowMultiple) {
1558+
super(allowMultiple);
1559+
}
1560+
15211561
@Override
15221562
void resolveName(AnnotationDefinition definition, SymbolTable.SymbolResolver resolver) {
15231563
verifyValuesCnt(definition, 1);

vadl/main/vadl/ast/AstUtils.java

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,37 @@ static BuiltInTable.BuiltIn getOperatorBuiltIn(Operator operator, List<Type> arg
8686
return switch (builtIns.size()) {
8787
case 0 -> throw new IllegalStateException(
8888
"Couldn't get any matching builtin for %s".formatted(operator));
89-
case 1 -> builtIns.get(0);
89+
case 1 -> builtIns.getFirst();
9090
case 2 -> {
91-
var isSigned = argTypes.getFirst().getClass() == SIntType.class;
91+
92+
93+
final var firstArgType = argTypes.getFirst().getClass();
94+
if (firstArgType == PseudoFormatType.class) {
95+
// For opequ/opneq, we select the overload only upon exact match
96+
builtIns = builtIns.stream()
97+
.filter(b -> b.signature().argTypeClasses().getFirst() == PseudoFormatType.class)
98+
.toList();
99+
} else {
100+
builtIns = builtIns.stream()
101+
.filter(b -> b.signature().argTypeClasses().getFirst() != PseudoFormatType.class)
102+
.toList();
103+
}
104+
105+
if (builtIns.size() == 1) {
106+
yield builtIns.getFirst();
107+
}
108+
109+
final var isSigned = firstArgType == SIntType.class;
92110
builtIns = builtIns.stream()
93-
.filter(b -> (b.signature().argTypeClasses().get(0) == SIntType.class) == isSigned)
111+
.filter(
112+
b -> (b.signature().argTypeClasses().getFirst() == SIntType.class) == isSigned)
94113
.toList();
114+
95115
if (builtIns.size() != 1) {
96116
throw new IllegalStateException("Couldn't find a builtin function");
97117
}
98-
yield builtIns.get(0);
118+
119+
yield builtIns.getFirst();
99120
}
100121
case 3 -> {
101122
int numSigned = argTypes.get(0).getClass() == SIntType.class ? 1 : 0;

vadl/main/vadl/ast/AstVisitor.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,14 @@ public Void visit(ExistsInThenExpr expr) {
778778
return null;
779779
}
780780

781+
@Override
782+
public Void visit(ForallThenExpr expr) {
783+
beforeTravel(expr);
784+
expr.children().forEach(this::travel);
785+
afterTravel(expr);
786+
return null;
787+
}
788+
781789
@Override
782790
public Void visit(ForallExpr expr) {
783791
beforeTravel(expr);

vadl/main/vadl/ast/BehaviorLowering.java

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.util.Arrays;
3838
import java.util.IdentityHashMap;
3939
import java.util.List;
40+
import java.util.Optional;
4041
import java.util.concurrent.atomic.AtomicReference;
4142
import java.util.stream.Collectors;
4243
import java.util.stream.Stream;
@@ -48,7 +49,7 @@
4849
import vadl.types.DataType;
4950
import vadl.types.MicroArchitectureType;
5051
import vadl.types.SIntType;
51-
import vadl.types.StatusType;
52+
import vadl.types.StructType;
5253
import vadl.types.Type;
5354
import vadl.types.UIntType;
5455
import vadl.utils.BigIntUtils;
@@ -65,6 +66,7 @@
6566
import vadl.viam.Instruction;
6667
import vadl.viam.Logic;
6768
import vadl.viam.Memory;
69+
import vadl.viam.Operation;
6870
import vadl.viam.Procedure;
6971
import vadl.viam.RegisterResource;
7072
import vadl.viam.RegisterTensor;
@@ -98,6 +100,8 @@
98100
import vadl.viam.graph.dependency.FuncParamNode;
99101
import vadl.viam.graph.dependency.LetNode;
100102
import vadl.viam.graph.dependency.MiaBuiltInCall;
103+
import vadl.viam.graph.dependency.OperationExistsNode;
104+
import vadl.viam.graph.dependency.OperationForAllNode;
101105
import vadl.viam.graph.dependency.ProcCallNode;
102106
import vadl.viam.graph.dependency.ReadArtificialResNode;
103107
import vadl.viam.graph.dependency.ReadMemNode;
@@ -1118,6 +1122,36 @@ private ExpressionNode visitIdentifyable(Expr expr) {
11181122
requireNonNull(index.computedTo));
11191123
}
11201124

1125+
if (computedTarget instanceof ForallThenExpr forallThenExpr) {
1126+
final var index = forallThenExpr.indices.stream()
1127+
.filter(idx -> idx.identifier().name.equals(innerName))
1128+
.findFirst()
1129+
.orElseThrow();
1130+
final var format = (PseudoFormatType) index.identifier().type();
1131+
1132+
final List<Operation> ops = format.operations().stream()
1133+
.map(viamLowering::fetch)
1134+
.filter(Optional::isPresent).map(Optional::get)
1135+
.map(Operation.class::cast)
1136+
.toList();
1137+
return new OperationForAllNode.Index(getViamType(format), ops);
1138+
}
1139+
1140+
if (computedTarget instanceof ExistsInThenExpr existsInThenExpr) {
1141+
final var index = existsInThenExpr.indices.stream()
1142+
.filter(idx -> idx.identifier().name.equals(innerName))
1143+
.findFirst()
1144+
.orElseThrow();
1145+
final var format = (PseudoFormatType) index.identifier().type();
1146+
1147+
final List<Operation> ops = format.operations().stream()
1148+
.map(viamLowering::fetch)
1149+
.filter(Optional::isPresent).map(Optional::get)
1150+
.map(Operation.class::cast)
1151+
.toList();
1152+
return new OperationForAllNode.Index(getViamType(format), ops);
1153+
}
1154+
11211155
// Function call without arguments (and no parenthesis)
11221156
if (computedTarget instanceof FunctionDefinition functionDefinition) {
11231157
var function = (Function) viamLowering.fetch(functionDefinition).orElseThrow();
@@ -1286,7 +1320,7 @@ private ExpressionNode visitSubCall(CallIndexExpr expr, ExpressionNode exprBefor
12861320
(DataType) getViamType(requireNonNull(subCall.formatFieldType)));
12871321
resultExpr =
12881322
visitSliceIndexCall(slice, subCall.formatFieldType, subCall.argsIndices);
1289-
} else if (exprBeforeSubcall.type() instanceof StatusType) {
1323+
} else if (exprBeforeSubcall.type() instanceof StructType) {
12901324
var indexing = new StructGetFieldNode(subCall.identifier().name, resultExpr, Type.bool());
12911325
resultExpr = visitSliceIndexCall(indexing, Type.bool(), subCall.argsIndices);
12921326
} else if (exprBeforeSubcall.type() == MicroArchitectureType.instruction()) {
@@ -1626,16 +1660,63 @@ public ExpressionNode visit(AsStrExpr expr) {
16261660

16271661
@Override
16281662
public ExpressionNode visit(ExistsInExpr expr) {
1629-
throw new RuntimeException(
1630-
"The behavior generator doesn't implement yet: " + expr.getClass().getSimpleName());
1663+
1664+
final var opDefs = new ArrayList<OperationDefinition>();
1665+
final var ops = new ArrayList<Operation>();
1666+
1667+
for (IsId op : expr.operations) {
1668+
final var opDef = requireNonNull((OperationDefinition) op.target());
1669+
opDefs.add(opDef);
1670+
viamLowering.fetch(opDef)
1671+
.ifPresent(o -> ops.add((Operation) o));
1672+
}
1673+
1674+
final var idxType = getViamType(PseudoFormatType.of(opDefs));
1675+
final var idx = new OperationForAllNode.Index(idxType, ops);
1676+
1677+
var type = getViamType(expr.type());
1678+
return new OperationExistsNode(type, idx);
16311679
}
16321680

16331681
@Override
16341682
public ExpressionNode visit(ExistsInThenExpr expr) {
1635-
throw new RuntimeException(
1636-
"The behavior generator doesn't implement yet: " + expr.getClass().getSimpleName());
1683+
1684+
final List<OperationForAllNode.Index> indices = new ArrayList<>();
1685+
for (ExistsInThenExpr.Index idx : expr.indices) {
1686+
final var format = (PseudoFormatType) idx.identifier().type();
1687+
final List<Operation> ops = format.operations().stream()
1688+
.map(viamLowering::fetch)
1689+
.filter(Optional::isPresent).map(Optional::get)
1690+
.map(Operation.class::cast)
1691+
.toList();
1692+
final var idxNode = new OperationForAllNode.Index(getViamType(format), ops);
1693+
indices.add(idxNode);
1694+
}
1695+
1696+
var body = fetch(expr.thenExpr);
1697+
var type = getViamType(expr.type());
1698+
return new OperationExistsNode(type, indices, body);
16371699
}
16381700

1701+
@Override
1702+
public ExpressionNode visit(ForallThenExpr expr) {
1703+
1704+
final List<OperationForAllNode.Index> indices = new ArrayList<>();
1705+
for (ForallThenExpr.Index idx : expr.indices) {
1706+
final var format = (PseudoFormatType) idx.identifier().type();
1707+
final List<Operation> ops = format.operations().stream()
1708+
.map(viamLowering::fetch)
1709+
.filter(Optional::isPresent).map(Optional::get)
1710+
.map(Operation.class::cast)
1711+
.toList();
1712+
final var idxNode = new OperationForAllNode.Index(getViamType(format), ops);
1713+
indices.add(idxNode);
1714+
}
1715+
1716+
var body = fetch(expr.thenExpr);
1717+
var type = getViamType(expr.type());
1718+
return new OperationForAllNode(type, indices, body);
1719+
}
16391720

16401721
@Override
16411722
public ExpressionNode visit(ForallExpr expr) {

vadl/main/vadl/ast/ConstantEvaluator.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,13 @@ public ConstantValue visit(ExistsInThenExpr expr) {
559559

560560
}
561561

562+
@Override
563+
public ConstantValue visit(ForallThenExpr expr) {
564+
throw new EvaluationError(
565+
"The constant evaluator cannot evaluate a %s.".formatted(expr.nodeName()),
566+
expr);
567+
}
568+
562569

563570
@Override
564571
public ConstantValue visit(ForallExpr expr) {

0 commit comments

Comments
 (0)