Skip to content

Commit f631335

Browse files
committed
[mlir][dxsa] Add BinaryWriter to translate from MLIR to DXSA
BinaryWriter translates an MLIR module in the DXSA dialect into a DXSA binary; it is the reverse of what BinaryParser does. The current implementation supports only standard instructions and needs to be extended to support custom instructions. The instruction table is moved into a separate file (InstrInfo.def) so that it can be shared between the Parser and the Writer, which build different data structures from it: the Parser maps opcodes to mnemonics, and the Writer maps mnemonics to opcodes. Tests are extended to round-trip MLIR in order to verify both the Parser and the Writer. We also compare the binary output with the input to make sure that we do not lose any data during translation.
1 parent 856b5d3 commit f631335

9 files changed

Lines changed: 706 additions & 318 deletions

File tree

mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def DXSA_IndexRelImm : DXSA_Op<"index.rel.imm"> {
128128
TODO
129129
}];
130130

131-
let arguments = (ins DXSA_OperandType:$operand, StrAttr:$op, I64Attr:$imm);
131+
let arguments = (ins DXSA_OperandType:$operand, StrAttr:$op, I32Attr:$imm);
132132
let results = (outs DXSA_IndexType:$index);
133133
let assemblyFormat = "$operand attr-dict";
134134
}

mlir/lib/Target/DXSA/BinaryParser.cpp

Lines changed: 2 additions & 314 deletions
Large diffs are not rendered by default.

mlir/lib/Target/DXSA/BinaryWriter.cpp

Lines changed: 352 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,359 @@
1717
using namespace mlir;
1818
using namespace llvm;
1919

20+
using OpcodeMap = llvm::DenseMap<StringRef, uint32_t>;
21+
22+
// Populates `opcodes` from the shared instruction table: every SET(...)
// entry found in InstrInfo.def stores its OpCode under the key Name.
// The NumOperands, PrecMask and OpClass fields of an entry are not used
// here. SET is defined only for the duration of the include.
static void initOpcodeMap(OpcodeMap &opcodes) {
23+
#define SET(OpCode, Name, NumOperands, PrecMask, OpClass) \
24+
opcodes[Name] = OpCode;
25+
#include "InstrInfo.def"
26+
#undef SET
27+
}
28+
29+
// Returns the D3D10_SB_OPERAND_INDEX_* representation that corresponds to
// the op defining an operand index:
//   - dxsa::IndexRel     -> relative,
//   - dxsa::IndexRelImm  -> 32 bit immediate plus relative,
//   - dxsa::IndexImm     -> 32 or 64 bit immediate, chosen by the type of
//                           its integer attribute.
// Any other op, or an immediate that is not a 32/64 bit integer attribute,
// produces a diagnostic at the op's location and a failure result.
static FailureOr<uint32_t> getIndexRepresentation(Operation *op) {
  if (isa<dxsa::IndexRel>(op))
    return D3D10_SB_OPERAND_INDEX_RELATIVE;

  if (isa<dxsa::IndexRelImm>(op))
    return D3D10_SB_OPERAND_INDEX_IMMEDIATE32_PLUS_RELATIVE;

  auto immediate = dyn_cast<dxsa::IndexImm>(op);
  if (!immediate) {
    emitError(op->getLoc(), "invalid index type");
    return failure();
  }

  auto value = dyn_cast<IntegerAttr>(immediate.getImm());
  if (!value) {
    emitError(op->getLoc(), "invalid immediate index");
    return failure();
  }

  // The width of the attribute type selects the immediate encoding.
  Type valueType = value.getType();
  if (valueType.isInteger(32))
    return D3D10_SB_OPERAND_INDEX_IMMEDIATE32;
  if (valueType.isInteger(64))
    return D3D10_SB_OPERAND_INDEX_IMMEDIATE64;

  emitError(op->getLoc(), "invalid immediate index type");
  return failure();
}
60+
61+
class Writer {
62+
public:
63+
Writer(raw_ostream &output) : output(output, endianness::little) {
64+
initOpcodeMap(opcodeMap);
65+
}
66+
67+
// Emits every dxsa::Instruction found in the module body, in block order.
// The module region must consist of exactly one block; ops in that block
// that are not dxsa::Instruction are skipped. Returns failure as soon as
// one instruction cannot be emitted.
LogicalResult emitModule(ModuleOp source) {
  Region &body = source.getRegion();
  if (!body.hasOneBlock()) {
    emitError(body.getLoc(), "region should contain only one block");
    return failure();
  }

  for (auto &nested : body.front()) {
    auto instruction = dyn_cast<dxsa::Instruction>(nested);
    if (!instruction)
      continue;

    if (failed(emitInstruction(instruction)))
      return failure();
  }

  return success();
}
83+
84+
// Emit an instruction and all its operands recursively.
85+
// FIXME: add extended instructions
86+
LogicalResult emitInstruction(dxsa::Instruction inst) {
87+
// Buffer all tokens for an instruction, so we can fixup
88+
// instruction length before emitting tokens to the output.
89+
buffer.clear();
90+
91+
auto opcodeIt = opcodeMap.find(inst.getMnemonic());
92+
if (opcodeIt == opcodeMap.end()) {
93+
emitError(inst.getLoc(), "unknown mnemonic");
94+
return failure();
95+
}
96+
97+
// First token is an opcode and length. Length is unknown until we
98+
// process all operands.
99+
uint32_t opcode = opcodeIt->second;
100+
uint32_t token = ENCODE_D3D10_SB_OPCODE_TYPE(opcode);
101+
buffer.push_back(token);
102+
103+
for (Value value : inst.getOperands()) {
104+
Operation *op = value.getDefiningOp();
105+
if (!op) {
106+
emitError(value.getLoc(), "undefined operand");
107+
return failure();
108+
}
109+
110+
if (auto operand = dyn_cast<dxsa::Operand>(*op)) {
111+
if (failed(emitOperand(operand))) {
112+
return failure();
113+
}
114+
continue;
115+
}
116+
117+
if (auto operand = dyn_cast<dxsa::OperandImm>(*op)) {
118+
if (failed(emitOperandImm(operand))) {
119+
return failure();
120+
}
121+
continue;
122+
}
123+
124+
emitError(op->getLoc(), "unexpected operand kind");
125+
return failure();
126+
}
127+
128+
// Fixup instruction length after all operands are accumulated in
129+
// the buffer.
130+
buffer[0] |= ENCODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(buffer.size());
131+
for (uint32_t token : buffer) {
132+
output.write(token);
133+
}
134+
135+
return success();
136+
}
137+
138+
// Emit an operand and all its indices recursively.
139+
LogicalResult emitOperand(dxsa::Operand op) {
140+
uint32_t token = ENCODE_D3D10_SB_OPERAND_TYPE(op.getType());
141+
142+
// Encode swizzle, mask, or one component selection.
143+
switch (op.getNumComponents()) {
144+
case 0: {
145+
token |=
146+
ENCODE_D3D10_SB_OPERAND_NUM_COMPONENTS(D3D10_SB_OPERAND_0_COMPONENT);
147+
break;
148+
}
149+
case 1: {
150+
token |=
151+
ENCODE_D3D10_SB_OPERAND_NUM_COMPONENTS(D3D10_SB_OPERAND_1_COMPONENT);
152+
break;
153+
}
154+
case 4: {
155+
token |=
156+
ENCODE_D3D10_SB_OPERAND_NUM_COMPONENTS(D3D10_SB_OPERAND_4_COMPONENT);
157+
if (auto mask = op.getMask()) {
158+
token |= ENCODE_D3D10_SB_OPERAND_4_COMPONENT_SELECTION_MODE(
159+
D3D10_SB_OPERAND_4_COMPONENT_MASK_MODE);
160+
token |= ENCODE_D3D10_SB_OPERAND_4_COMPONENT_MASK(*mask);
161+
} else if (auto swizzle = op.getSwizzle()) {
162+
SmallVector<uint32_t, 4> values;
163+
for (APInt v : *swizzle) {
164+
values.push_back(v.getZExtValue());
165+
}
166+
if (values.size() != 4) {
167+
emitError(op.getLoc(), "invalid number of swizzle values");
168+
return failure();
169+
}
170+
token |= ENCODE_D3D10_SB_OPERAND_4_COMPONENT_SELECTION_MODE(
171+
D3D10_SB_OPERAND_4_COMPONENT_SWIZZLE_MODE);
172+
token |= ENCODE_D3D10_SB_OPERAND_4_COMPONENT_SWIZZLE(
173+
values[0], values[1], values[2], values[3]);
174+
break;
175+
} else if (auto one = op.getOne()) {
176+
token |= ENCODE_D3D10_SB_OPERAND_4_COMPONENT_SELECTION_MODE(
177+
D3D10_SB_OPERAND_4_COMPONENT_SELECT_1_MODE);
178+
token |= ENCODE_D3D10_SB_OPERAND_4_COMPONENT_SELECT_1(*one);
179+
break;
180+
}
181+
break;
182+
}
183+
default: {
184+
emitError(op.getLoc(), "invalid number of components");
185+
return failure();
186+
}
187+
}
188+
189+
// Operand token encodes types and number of indices that follow
190+
// it.
191+
token |= ENCODE_D3D10_SB_OPERAND_INDEX_DIMENSION(op.getNumOperands());
192+
uint32_t dim = 0;
193+
for (Value value : op.getOperands()) {
194+
Operation *index = value.getDefiningOp();
195+
if (!index) {
196+
emitError(value.getLoc(), "index must be defined");
197+
return failure();
198+
}
199+
200+
FailureOr<uint32_t> repr = getIndexRepresentation(index);
201+
if (failed(repr)) {
202+
return failure();
203+
}
204+
token |= ENCODE_D3D10_SB_OPERAND_INDEX_REPRESENTATION(dim, *repr);
205+
dim += 1;
206+
}
207+
208+
buffer.push_back(token);
209+
210+
// Indices follow the operand token.
211+
for (Value value : op.getOperands()) {
212+
Operation *index = value.getDefiningOp();
213+
if (!index) {
214+
emitError(value.getLoc(), "index must be defined");
215+
return failure();
216+
}
217+
218+
if (auto indexImm = dyn_cast<dxsa::IndexImm>(*index)) {
219+
if (failed(emitIndexImm(indexImm))) {
220+
return failure();
221+
}
222+
continue;
223+
}
224+
225+
if (auto indexRel = dyn_cast<dxsa::IndexRel>(*index)) {
226+
if (failed(emitIndexRel(indexRel))) {
227+
return failure();
228+
}
229+
continue;
230+
}
231+
232+
if (auto indexRelImm = dyn_cast<dxsa::IndexRelImm>(*index)) {
233+
if (failed(emitIndexRelImm(indexRelImm))) {
234+
return failure();
235+
}
236+
continue;
237+
}
238+
239+
emitError(value.getLoc(), "invalid index type");
240+
return failure();
241+
}
242+
243+
return success();
244+
}
245+
246+
// Emit an immediate operand. Unlike register operands, immediate
247+
// operands do not have indices. They are encoded as an operand
248+
// followed by N immediate values for each component.
249+
LogicalResult emitOperandImm(dxsa::OperandImm op) {
250+
auto attr = dyn_cast<DenseIntElementsAttr>(op.getImm());
251+
if (!attr) {
252+
emitError(op.getLoc(), "invalid immediate operand");
253+
}
254+
255+
uint32_t token = 0;
256+
257+
Type elementType = attr.getType().getElementType();
258+
if (elementType.isInteger(32)) {
259+
token |= ENCODE_D3D10_SB_OPERAND_TYPE(D3D10_SB_OPERAND_TYPE_IMMEDIATE32);
260+
} else if (elementType.isInteger(64)) {
261+
token |= ENCODE_D3D10_SB_OPERAND_TYPE(D3D10_SB_OPERAND_TYPE_IMMEDIATE64);
262+
} else {
263+
emitError(op.getLoc(), "invalid immediate operand type");
264+
return failure();
265+
}
266+
267+
// Split immediates into tokens. 32 bit immediate values are
268+
// encoded as is, and 64 bit immediates are split into high and
269+
// low 32 bit parts.
270+
SmallVector<uint32_t, 4> values;
271+
for (APInt v : attr) {
272+
uint64_t bits = v.getZExtValue();
273+
if (v.getBitWidth() == 64) {
274+
values.push_back(bits >> 32);
275+
}
276+
values.push_back(bits);
277+
}
278+
279+
if (values.size() == 1) {
280+
token |=
281+
ENCODE_D3D10_SB_OPERAND_NUM_COMPONENTS(D3D10_SB_OPERAND_1_COMPONENT);
282+
} else if (values.size() == 4) {
283+
token |=
284+
ENCODE_D3D10_SB_OPERAND_NUM_COMPONENTS(D3D10_SB_OPERAND_4_COMPONENT);
285+
} else {
286+
emitError(op.getLoc(),
287+
"immediate operand should be either 1- or 4- component");
288+
return failure();
289+
}
290+
291+
buffer.push_back(token);
292+
for (uint32_t v : values) {
293+
buffer.push_back(v);
294+
}
295+
296+
return success();
297+
}
298+
299+
// Emit an immediate index. Its type is encoded into the operand, so
300+
// here we only emit the value as tokens.
301+
LogicalResult emitIndexImm(dxsa::IndexImm op) {
302+
auto attr = dyn_cast<IntegerAttr>(op.getImm());
303+
if (!attr) {
304+
emitError(op.getLoc(), "invalid immediate index");
305+
return failure();
306+
}
307+
308+
uint64_t value = attr.getInt();
309+
if (attr.getType().isInteger(32)) {
310+
buffer.push_back(value);
311+
return success();
312+
}
313+
314+
if (attr.getType().isInteger(64)) {
315+
buffer.push_back(value >> 32);
316+
buffer.push_back(value);
317+
return success();
318+
}
319+
320+
emitError(op.getLoc(), "invalid type of an immediate index");
321+
return failure();
322+
}
323+
324+
// Emit an operand used as an index.
325+
LogicalResult emitIndexRel(dxsa::IndexRel index) {
326+
Operation *def = index.getOperand().getDefiningOp();
327+
if (!def) {
328+
emitError(index.getLoc(), "index must be defined");
329+
return failure();
330+
}
331+
332+
auto operand = dyn_cast<dxsa::Operand>(*def);
333+
if (!operand) {
334+
emitError(def->getLoc(), "invalid index relative operand");
335+
return failure();
336+
}
337+
338+
// Recursively emit an operand, which may also have other indices.
339+
return emitOperand(operand);
340+
}
341+
342+
// Emit an index as an operand + a 32 bit immediate offset.
343+
LogicalResult emitIndexRelImm(dxsa::IndexRelImm index) {
344+
Operation *def = index.getOperand().getDefiningOp();
345+
if (!def) {
346+
emitError(index.getLoc(), "index must be defined");
347+
return failure();
348+
}
349+
350+
auto operand = dyn_cast<dxsa::Operand>(*def);
351+
if (!operand) {
352+
emitError(def->getLoc(), "invalid index relative operand");
353+
return failure();
354+
}
355+
356+
if (failed(emitOperand(operand))) {
357+
return failure();
358+
}
359+
360+
buffer.push_back(index.getImm());
361+
return success();
362+
}
363+
364+
private:
365+
std::vector<uint32_t> buffer;
366+
support::endian::Writer output;
367+
OpcodeMap opcodeMap;
368+
};
369+
20370
namespace mlir::dxsa {
21371
LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output) {
22-
Region &region = source.getRegion();
23-
assert(region.hasOneBlock() && "invalid module");
24-
return failure();
372+
Writer writer(output);
373+
return writer.emitModule(source);
25374
}
26375
} // namespace mlir::dxsa

0 commit comments

Comments
 (0)