Skip to content

Commit 991c672

Browse files
authored
chore: Refactor tuple nodes and access (#963)
As discussed in the PR (See #918 and [this thread](#918 (comment))) regarding group annotations, this refactors the current usages of the Tuple nodes to a more general `Struct` which allows field access as well. Tuples are currently used in two places, namely `LetExpr`essions and `LetStatements` as well as `Status` types. For `Let` nodes, the result of the value expression may have a `StructType` with fields `result` and `status`. E.g.: ```vadl let next, status = VADL::adds(PC, 4 as Bits<32>) in if status.zero then PC := next ``` During behavior lowering the name `next` will be mapped to a `StructGetFieldNode`, accessing the field `result` on the expression `VADL::adds(PC, 4 as Bits<32>)`. `StatusType` now exposes fields `negative`, `zero`, `carry` and `overflow`, as opposed to positional indexing.
1 parent 25c4b2b commit 991c672

42 files changed

Lines changed: 594 additions & 446 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

vadl/main/vadl/ast/BehaviorLowering.java

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import vadl.types.DataType;
4949
import vadl.types.MicroArchitectureType;
5050
import vadl.types.SIntType;
51+
import vadl.types.StatusType;
5152
import vadl.types.Type;
5253
import vadl.types.UIntType;
5354
import vadl.utils.BigIntUtils;
@@ -108,9 +109,9 @@
108109
import vadl.viam.graph.dependency.SignExtendNode;
109110
import vadl.viam.graph.dependency.SliceNode;
110111
import vadl.viam.graph.dependency.StageEffectNode;
112+
import vadl.viam.graph.dependency.StructGetFieldNode;
111113
import vadl.viam.graph.dependency.TensorNode;
112114
import vadl.viam.graph.dependency.TruncateNode;
113-
import vadl.viam.graph.dependency.TupleGetFieldNode;
114115
import vadl.viam.graph.dependency.WriteArtificialResNode;
115116
import vadl.viam.graph.dependency.WriteMemNode;
116117
import vadl.viam.graph.dependency.WriteRegTensorNode;
@@ -781,7 +782,6 @@ private SubgraphContext fetch(Statement stmt) {
781782
}
782783

783784

784-
785785
/// This utility function can be used to fill in missing indexes of a tensor.
786786
///
787787
/// It basically returns a permutation of all possible indices for the dimensions provided.
@@ -1067,19 +1067,18 @@ private ExpressionNode visitIdentifyable(Expr expr) {
10671067
// Let statement and expression
10681068
if (computedTarget instanceof LetStatement letStatement) {
10691069
var expression = fetch(letStatement.valueExpr);
1070-
var index = letStatement.getIndexOf(innerName);
10711070
if (letStatement.identifiers.size() > 1) {
1072-
expression = new TupleGetFieldNode(index, expression,
1071+
expression = new StructGetFieldNode(letStatement.mapName(innerName), expression,
10731072
getViamType(letStatement.getTypeOf(innerName)));
10741073
}
10751074
return new LetNode(new LetNode.Name(innerName, letStatement.location()), expression);
10761075
}
10771076
if (computedTarget instanceof LetExpr letExpr) {
10781077
var expression = fetch(letExpr.valueExpr);
1079-
var index = letExpr.getIndexOf(innerName);
10801078
if (letExpr.identifiers.size() > 1) {
10811079
expression =
1082-
new TupleGetFieldNode(index, expression, getViamType(letExpr.getTypeOf(innerName)));
1080+
new StructGetFieldNode(letExpr.mapName(innerName), expression,
1081+
getViamType(letExpr.getTypeOf(innerName)));
10831082
}
10841083
return new LetNode(new LetNode.Name(innerName, letExpr.location()), expression);
10851084
}
@@ -1173,7 +1172,7 @@ public ExpressionNode visit(GroupedExpr expr) {
11731172

11741173
var type = expr.type().equals(Type.string()) ? expr.type() :
11751174
Type.bits(expr.expressions.get(0).type().asDataType()
1176-
.bitWidth() + expr.expressions.get(1).type().asDataType().bitWidth());
1175+
.bitWidth() + expr.expressions.get(1).type().asDataType().bitWidth());
11771176

11781177
var call = new BuiltInCall(concatBuiltin,
11791178
new NodeList<>(expr.expressions.get(0).accept(this),
@@ -1183,7 +1182,7 @@ public ExpressionNode visit(GroupedExpr expr) {
11831182
for (int i = 2; i < expr.expressions.size(); i++) {
11841183
type = expr.type().equals(Type.string()) ? expr.type() :
11851184
Type.bits(type.asDataType().bitWidth()
1186-
+ expr.expressions.get(i).type().asDataType().bitWidth());
1185+
+ expr.expressions.get(i).type().asDataType().bitWidth());
11871186
call = new BuiltInCall(concatBuiltin,
11881187
new NodeList<>(call,
11891188
expr.expressions.get(i).accept(this)),
@@ -1284,9 +1283,8 @@ private ExpressionNode visitSubCall(CallIndexExpr expr, ExpressionNode exprBefor
12841283
(DataType) getViamType(requireNonNull(subCall.formatFieldType)));
12851284
resultExpr =
12861285
visitSliceIndexCall(slice, subCall.formatFieldType, subCall.argsIndices);
1287-
} else if (subCall.computedStatusIndex != null) {
1288-
var indexing =
1289-
new TupleGetFieldNode(subCall.computedStatusIndex, resultExpr, Type.bool());
1286+
} else if (exprBeforeSubcall.type() instanceof StatusType) {
1287+
var indexing = new StructGetFieldNode(subCall.identifier().name, resultExpr, Type.bool());
12901288
resultExpr = visitSliceIndexCall(indexing, Type.bool(), subCall.argsIndices);
12911289
} else if (exprBeforeSubcall.type() == MicroArchitectureType.instruction()) {
12921290
// There is weired way to call functions on instructions
@@ -1413,7 +1411,7 @@ public ExpressionNode visitStageCall(CallIndexExpr expr, StageDefinition stageDe
14131411
.filter(o -> o.identifier().name.equals(subcall.identifier().name))
14141412
.findFirst()
14151413
.get()
1416-
).get();
1414+
).get();
14171415
return new ReadStageOutputNode(output);
14181416
}
14191417

@@ -1811,10 +1809,10 @@ public SubgraphContext visit(AssignmentStatement statement) {
18111809
/**
18121810
* Method that prepares the value so that it can be used for a dynamic write of a resource.
18131811
*
1814-
* @param value value that is being written (right side of assignment)
1815-
* @param entireRead resource value before value is written
1816-
* @param index the dynamic expression of the index.
1817-
* @return that incorporates the written value into the resource.
1812+
* @param value value that is being written (right side of assignment)
1813+
* @param entireRead resource value before value is written
1814+
* @param index the dynamic expression of the index.
1815+
* @return that incorporates the written value into the resource.
18181816
*/
18191817
private ExpressionNode dynamicIndexWriteValue(ExpressionNode value, ReadResourceNode entireRead,
18201818
@Nullable ExpressionNode index) {

vadl/main/vadl/ast/ConstantEvaluator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ private ConstantValue visitIdentifiable(Expr expr) {
199199
}
200200

201201
if (origin instanceof LetExpr letExpr) {
202-
// FIXME: implement tuple unpacking
202+
// FIXME: implement field unpacking
203203
if (letExpr.identifiers.size() > 1) {
204-
throw new EvaluationError("Cannot evaluate tuple unpacking yet",
204+
throw new EvaluationError("Cannot evaluate field unpacking yet",
205205
letExpr.identifiers().getFirst().loc.join(
206206
letExpr.identifiers().getLast().loc));
207207
}

vadl/main/vadl/ast/Expr.java

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import vadl.types.BuiltInTable;
3232
import vadl.types.ConcreteRelationType;
3333
import vadl.types.SIntType;
34-
import vadl.types.TupleType;
34+
import vadl.types.StructType;
3535
import vadl.types.Type;
3636
import vadl.types.UIntType;
3737
import vadl.utils.SourceLocation;
@@ -2030,13 +2030,6 @@ static final class SubCall implements WithLocation {
20302030
@Nullable
20312031
Constant.BitSlice computedBitSlice;
20322032

2033-
/**
2034-
* If the subcall is status access, this field tells which index in the status type the
2035-
* field is.
2036-
*/
2037-
@Nullable
2038-
public Integer computedStatusIndex;
2039-
20402033
SubCall(IdentifierOrPlaceholder id, List<Arguments> argsIndices) {
20412034
this.id = id;
20422035
this.argsIndices = argsIndices;
@@ -2169,16 +2162,34 @@ List<Identifier> identifiers() {
21692162
}
21702163

21712164
/**
2172-
* Returns the index of one of the variables the statement defines.
2165+
* Translates the outer name of the let expression to the inner name of the value expression.
2166+
* E.g.:
21732167
*
2174-
* @return the type of the name provided.
2168+
* <pre>
2169+
* let next, status = VADL::adds(PC, 4 as Bits<32>) in
2170+
* ...
2171+
* </pre>
2172+
* this method will translate "next" to "result" and "status" to "status".
2173+
*
2174+
* @param name the bound name of the let expression.
2175+
* @return the name of the value expression.
21752176
*/
2176-
int getIndexOf(String name) {
2177-
return identifiers().stream().map(i -> i.name).toList().indexOf(name);
2177+
String mapName(String name) {
2178+
var valType = valueExpr.type;
2179+
if (!(valType instanceof StructType struct)) {
2180+
throw new IllegalStateException("Expected StructType but got " + valType);
2181+
}
2182+
final List<String> fields = struct.fieldNames();
2183+
for (var i = 0; i < identifiers.size(); i++) {
2184+
if (name.equals(identifiers().get(i).name)) {
2185+
return fields.get(i);
2186+
}
2187+
}
2188+
throw new IllegalStateException("Let expression does not have a name `%s`.".formatted(name));
21782189
}
21792190

21802191
/**
2181-
* Returns the type of one of the variables the statement defines.
2192+
* Returns the type of one of the variables the expression defines.
21822193
*
21832194
* @return the type of the name provided.
21842195
*/
@@ -2188,11 +2199,11 @@ Type getTypeOf(String name) {
21882199
return requireNonNull(valType);
21892200
}
21902201

2191-
if (!(valType instanceof TupleType valTuple)) {
2192-
throw new IllegalStateException("Expected TupleType but got " + valType);
2202+
if (!(valType instanceof StructType valStruct)) {
2203+
throw new IllegalStateException("Expected StructType but got " + valType);
21932204
}
21942205

2195-
return requireNonNull(valTuple.get(getIndexOf(name)));
2206+
return requireNonNull(valStruct.get(mapName(name)));
21962207
}
21972208

21982209
@Override

vadl/main/vadl/ast/Statement.java

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import java.util.function.Consumer;
2424
import javax.annotation.Nullable;
2525
import vadl.javaannotations.ast.Child;
26-
import vadl.types.TupleType;
26+
import vadl.types.StructType;
2727
import vadl.types.Type;
2828
import vadl.utils.SourceLocation;
2929
import vadl.utils.WithLocation;
@@ -124,7 +124,7 @@ <R> R accept(StatementVisitor<R> visitor) {
124124
}
125125

126126
/**
127-
* If multiple identifiers are provided, they are used to unpack a tuple.
127+
* If multiple identifiers are provided, they are used to unpack fields of a struct.
128128
*/
129129
final class LetStatement extends Statement {
130130
List<IsId> identifiers;
@@ -147,18 +147,31 @@ List<Identifier> identifiers() {
147147
}
148148

149149
/**
150-
* Returns the index of one of the variables the statement defines.
150+
* Translates the outer name of the let statement to the inner name of the value expression.
151+
* E.g.:
151152
*
152-
* @return the index/offset of the name provided.
153+
* <pre>
154+
* let next, status = VADL::adds(PC, 4 as Bits<32>) in
155+
* ...
156+
* </pre>
157+
* this method will translate "next" to "result" and "status" to "status".
158+
*
159+
* @param name the bound name of the let statement.
160+
* @return the name of the value expression.
153161
*/
154-
int getIndexOf(String name) {
155-
for (var idx = 0; idx < identifiers.size(); idx++) {
156-
if (((Identifier) identifiers.get(idx)).name.equals(name)) {
157-
return idx;
158-
}
162+
String mapName(String name) {
163+
var valType = valueExpr.type;
164+
if (!(valType instanceof StructType struct)) {
165+
throw new IllegalStateException("Expected StructType but got " + valType);
159166
}
160167

161-
return -1;
168+
final List<String> fields = struct.fieldNames();
169+
for (var i = 0; i < identifiers.size(); i++) {
170+
if (name.equals(identifiers().get(i).name)) {
171+
return fields.get(i);
172+
}
173+
}
174+
throw new IllegalStateException("Let statement does not have a name `%s`.".formatted(name));
162175
}
163176

164177
/**
@@ -172,11 +185,11 @@ Type getTypeOf(String name) {
172185
return Objects.requireNonNull(valType);
173186
}
174187

175-
if (!(valType instanceof TupleType valTuple)) {
176-
throw new IllegalStateException("Expected TupleType but got " + valType);
188+
if (!(valType instanceof StructType valStruct)) {
189+
throw new IllegalStateException("Expected StructType but got " + valType);
177190
}
178191

179-
return Objects.requireNonNull(valTuple.get(getIndexOf(name)));
192+
return Objects.requireNonNull(valStruct.fields().get(mapName(name)));
180193
}
181194

182195
@Override

vadl/main/vadl/ast/TypeChecker.java

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
import vadl.types.SIntType;
5656
import vadl.types.StatusType;
5757
import vadl.types.StringType;
58-
import vadl.types.TupleType;
58+
import vadl.types.StructType;
5959
import vadl.types.Type;
6060
import vadl.types.UIntType;
6161
import vadl.types.asmTypes.AsmType;
@@ -3762,7 +3762,6 @@ private void visitSubCall(CallIndexExpr expr, Type typeBeforeSubCall) {
37623762
.build());
37633763
}
37643764
var fieldType = Type.bool();
3765-
subCall.computedStatusIndex = allowedStatusfields.indexOf(fieldName);
37663765
visitSliceIndexCall(expr, fieldType, subCall.argsIndices);
37673766
type = expr.type;
37683767
} else if (type instanceof InstructionType) {
@@ -4085,23 +4084,24 @@ public Void visit(LetExpr expr) {
40854084
var valType = check(expr.valueExpr);
40864085

40874086
if (expr.identifiers.size() > 1) {
4088-
if (!(valType instanceof TupleType valTupleType)) {
4089-
var loc = expr.identifiers().get(0).loc.join(expr.valueExpr.location());
4087+
if (!(valType instanceof StructType valStructType)) {
4088+
var loc = expr.identifiers().getFirst().loc.join(expr.valueExpr.location());
40904089
throw addErrorAndStopChecking(error("Type Mismatch", loc)
4091-
.description("Tuple unpacking only works on tuples but the type was `%s`", valType)
4090+
.description("Field unpacking only works on structs but the type was `%s`", valType)
40924091
.build());
40934092
}
40944093

4095-
if (expr.identifiers.size() != valTupleType.size()) {
4096-
var loc = expr.identifiers().get(0).loc.join(expr.valueExpr.location());
4097-
addErrorAndStopChecking(error("Invalid Tuple Unpacking", loc)
4098-
.description("Cannot unpack %d values form a `%s`.", expr.identifiers.size(),
4094+
if (expr.identifiers.size() != valStructType.size()) {
4095+
var loc = expr.identifiers().getFirst().loc.join(expr.valueExpr.location());
4096+
throw addErrorAndStopChecking(error("Invalid Field Unpacking", loc)
4097+
.description("Cannot unpack %d values from a `%s`.", expr.identifiers.size(),
40994098
valType)
41004099
.build());
41014100
}
41024101

4102+
var valTypes = valStructType.types();
41034103
for (int i = 0; i < expr.identifiers.size(); i++) {
4104-
expr.identifiers().get(i).type = valTupleType.get(i);
4104+
expr.identifiers().get(i).type = valTypes.get(i);
41054105
}
41064106
} else {
41074107
expr.identifiers().getFirst().type = valType;
@@ -4365,23 +4365,24 @@ public Void visit(LetStatement statement) {
43654365
var valType = check(statement.valueExpr);
43664366

43674367
if (statement.identifiers.size() > 1) {
4368-
if (!(valType instanceof TupleType valTupleType)) {
4369-
var loc = statement.identifiers().get(0).loc.join(statement.valueExpr.location());
4368+
if (!(valType instanceof StructType valStructType)) {
4369+
var loc = statement.identifiers().getFirst().loc.join(statement.valueExpr.location());
43704370
throw addErrorAndStopChecking(error("Type Mismatch", loc)
4371-
.description("Tuple unpacking only works on tuples but the type was `%s`", valType)
4371+
.description("Field unpacking only works on structs but the type was `%s`", valType)
43724372
.build());
43734373
}
43744374

4375-
if (statement.identifiers.size() != valTupleType.size()) {
4376-
var loc = statement.identifiers().get(0).loc.join(statement.valueExpr.location());
4377-
addErrorAndStopChecking(error("Invalid Tuple Unpacking", loc)
4378-
.description("Cannot unpack %d values form a `%s`.", statement.identifiers.size(),
4375+
if (statement.identifiers.size() != valStructType.size()) {
4376+
var loc = statement.identifiers().getFirst().loc.join(statement.valueExpr.location());
4377+
throw addErrorAndStopChecking(error("Invalid Field Unpacking", loc)
4378+
.description("Cannot unpack %d values from a `%s`.", statement.identifiers.size(),
43794379
valType)
43804380
.build());
43814381
}
43824382

4383+
var valTypes = valStructType.types();
43834384
for (int i = 0; i < statement.identifiers.size(); i++) {
4384-
statement.identifiers().get(i).type = valTupleType.get(i);
4385+
statement.identifiers().get(i).type = valTypes.get(i);
43854386
}
43864387
} else {
43874388
statement.identifiers().getFirst().type = valType;

vadl/main/vadl/cppCodeGen/mixins/CDefaultMixins.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@
5656
import vadl.viam.graph.dependency.SelectNode;
5757
import vadl.viam.graph.dependency.SignExtendNode;
5858
import vadl.viam.graph.dependency.SliceNode;
59+
import vadl.viam.graph.dependency.StructGetFieldNode;
5960
import vadl.viam.graph.dependency.TensorNode;
6061
import vadl.viam.graph.dependency.TruncateNode;
61-
import vadl.viam.graph.dependency.TupleGetFieldNode;
6262
import vadl.viam.graph.dependency.ZeroExtendNode;
6363
import vadl.viam.passes.sideEffectScheduling.nodes.InstrExitNode;
6464

@@ -328,7 +328,7 @@ interface AllDependencies extends AllExpressions {
328328
@SuppressWarnings("MissingJavadocType")
329329
interface AllExpressions
330330
extends TypeCasts, Constant, FuncCall, BuiltIns, Slice, LetNode, Select, FuncParam, ForallIdx,
331-
TupleAccess, Label {
331+
FieldAccess, Label {
332332

333333
}
334334

@@ -449,10 +449,10 @@ default void handle(CGenContext<Node> ctx, SelectNode toHandle) {
449449
}
450450

451451
@SuppressWarnings("MissingJavadocType")
452-
interface TupleAccess {
452+
interface FieldAccess {
453453
@Handler
454-
default void handle(CGenContext<Node> ctx, TupleGetFieldNode toHandle) {
455-
throw new UnsupportedOperationException("Type TupleGetFieldNode not yet implemented");
454+
default void handle(CGenContext<Node> ctx, StructGetFieldNode toHandle) {
455+
throw new UnsupportedOperationException("Type StructGetFieldNode not yet implemented");
456456
}
457457
}
458458

0 commit comments

Comments
 (0)