Skip to content

Commit 6657041

Browse files
committed
Replace checks for invariants with casts and assertions, add more TypeSwitch
1 parent c4948cd commit 6657041

1 file changed

Lines changed: 31 additions & 66 deletions

File tree

mlir/lib/Target/DXSA/BinaryWriter.cpp

Lines changed: 31 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,24 @@ static void initOpcodeMap(OpcodeMap &opcodes) {
2828
}
2929

3030
static FailureOr<uint32_t> getIndexRepresentation(Operation *op) {
31-
if (auto imm = dyn_cast<dxsa::IndexImm>(op)) {
32-
auto attr = dyn_cast<IntegerAttr>(imm.getImm());
33-
if (!attr) {
34-
return emitError(op->getLoc(), "invalid immediate index");
35-
}
36-
37-
if (attr.getType().isInteger(32)) {
38-
return D3D10_SB_OPERAND_INDEX_IMMEDIATE32;
39-
}
40-
41-
if (attr.getType().isInteger(64)) {
42-
return D3D10_SB_OPERAND_INDEX_IMMEDIATE64;
43-
}
44-
45-
return emitError(op->getLoc(), "invalid immediate index type");
46-
}
47-
48-
if (isa<dxsa::IndexRel>(op)) {
49-
return D3D10_SB_OPERAND_INDEX_RELATIVE;
50-
}
51-
52-
if (isa<dxsa::IndexRelImm>(op)) {
53-
return D3D10_SB_OPERAND_INDEX_IMMEDIATE32_PLUS_RELATIVE;
54-
}
55-
56-
return emitError(op->getLoc(), "invalid index type");
31+
return llvm::TypeSwitch<Operation &, FailureOr<uint32_t>>(*op)
32+
.Case<dxsa::IndexImm>([](auto imm) {
33+
auto attr = cast<IntegerAttr>(imm.getImm());
34+
auto type = cast<IntegerType>(attr.getType());
35+
if (type.getWidth() == 32) {
36+
return D3D10_SB_OPERAND_INDEX_IMMEDIATE32;
37+
}
38+
assert(type.getWidth() == 64 && "invalid index type");
39+
return D3D10_SB_OPERAND_INDEX_IMMEDIATE64;
40+
})
41+
.Case<dxsa::IndexRel>(
42+
[](auto imm) { return D3D10_SB_OPERAND_INDEX_RELATIVE; })
43+
.Case<dxsa::IndexRelImm>([](auto imm) {
44+
return D3D10_SB_OPERAND_INDEX_IMMEDIATE32_PLUS_RELATIVE;
45+
})
46+
.Default([](auto &op) {
47+
return emitError(op.getLoc(), "invalid index type");
48+
});
5749
}
5850

5951
class Writer {
@@ -183,9 +175,7 @@ class Writer {
183175
uint32_t dim = 0;
184176
for (Value value : op.getOperands()) {
185177
Operation *index = value.getDefiningOp();
186-
if (!index) {
187-
return emitError(value.getLoc(), "index must be defined");
188-
}
178+
assert(index && "undefined index");
189179

190180
FailureOr<uint32_t> repr = getIndexRepresentation(index);
191181
if (failed(repr)) {
@@ -200,17 +190,15 @@ class Writer {
200190
// Indices follow the operand token.
201191
for (Value value : op.getOperands()) {
202192
Operation *index = value.getDefiningOp();
203-
if (!index) {
204-
return emitError(value.getLoc(), "index must be defined");
205-
}
193+
assert(index && "undefined index");
206194

207195
auto result = llvm::TypeSwitch<Operation &, LogicalResult>(*index)
208196
.Case<dxsa::IndexImm>(
209-
[this](auto op) { return emitIndexImm(op); })
197+
[this](auto &op) { return emitIndexImm(op); })
210198
.Case<dxsa::IndexRel>(
211-
[this](auto op) { return emitIndexRel(op); })
199+
[this](auto &op) { return emitIndexRel(op); })
212200
.Case<dxsa::IndexRelImm>(
213-
[this](auto op) { return emitIndexRelImm(op); })
201+
[this](auto &op) { return emitIndexRelImm(op); })
214202
.Default([this](auto &op) {
215203
return emitError(op.getLoc(), "invalid index type");
216204
});
@@ -227,20 +215,16 @@ class Writer {
227215
// operands do not have indices. They are encoded as an operand
228216
// followed by N immediate values for each component.
229217
LogicalResult emitOperandImm(dxsa::OperandImm op) {
230-
auto attr = dyn_cast<DenseIntElementsAttr>(op.getImm());
231-
if (!attr) {
232-
return emitError(op.getLoc(), "invalid immediate operand");
233-
}
218+
auto attr = cast<DenseIntElementsAttr>(op.getImm());
234219

235220
uint32_t token = 0;
236221

237-
Type elementType = attr.getType().getElementType();
238-
if (elementType.isInteger(32)) {
222+
auto elementType = cast<IntegerType>(attr.getType().getElementType());
223+
if (elementType.getWidth() == 32) {
239224
token |= ENCODE_D3D10_SB_OPERAND_TYPE(D3D10_SB_OPERAND_TYPE_IMMEDIATE32);
240-
} else if (elementType.isInteger(64)) {
241-
token |= ENCODE_D3D10_SB_OPERAND_TYPE(D3D10_SB_OPERAND_TYPE_IMMEDIATE64);
242225
} else {
243-
return emitError(op.getLoc(), "invalid immediate operand type");
226+
assert(elementType.getWidth() == 64 && "invalid immediate");
227+
token |= ENCODE_D3D10_SB_OPERAND_TYPE(D3D10_SB_OPERAND_TYPE_IMMEDIATE64);
244228
}
245229

246230
// Split immediates into tokens. 32 bit immediate values are
@@ -275,10 +259,7 @@ class Writer {
275259
// Emit an immediate index. Its type is encoded into the operand, so
276260
// here we only emit the value as tokens.
277261
LogicalResult emitIndexImm(dxsa::IndexImm op) {
278-
auto attr = dyn_cast<IntegerAttr>(op.getImm());
279-
if (!attr) {
280-
return emitError(op.getLoc(), "invalid immediate index");
281-
}
262+
auto attr = cast<IntegerAttr>(op.getImm());
282263

283264
uint64_t value = attr.getInt();
284265
if (attr.getType().isInteger(32)) {
@@ -297,31 +278,15 @@ class Writer {
297278

298279
// Emit an operand used as an index.
299280
LogicalResult emitIndexRel(dxsa::IndexRel index) {
300-
Operation *def = index.getOperand().getDefiningOp();
301-
if (!def) {
302-
return emitError(index.getLoc(), "index must be defined");
303-
}
304-
305-
auto operand = dyn_cast<dxsa::Operand>(*def);
306-
if (!operand) {
307-
return emitError(def->getLoc(), "invalid index relative operand");
308-
}
281+
auto operand = cast<dxsa::Operand>(index.getOperand().getDefiningOp());
309282

310283
// Recursively emit an operand, which may also have other indices.
311284
return emitOperand(operand);
312285
}
313286

314287
// Emit an index as an operand + a 32 bit immediate offset.
315288
LogicalResult emitIndexRelImm(dxsa::IndexRelImm index) {
316-
Operation *def = index.getOperand().getDefiningOp();
317-
if (!def) {
318-
return emitError(index.getLoc(), "index must be defined");
319-
}
320-
321-
auto operand = dyn_cast<dxsa::Operand>(*def);
322-
if (!operand) {
323-
return emitError(def->getLoc(), "invalid index relative operand");
324-
}
289+
auto operand = cast<dxsa::Operand>(index.getOperand().getDefiningOp());
325290

326291
if (failed(emitOperand(operand))) {
327292
return failure();

0 commit comments

Comments
 (0)