Skip to content

Commit da03148

Browse files
authored
[mlir][bytecode] Unpack i1 splats to 0x01 (llvm#186221)
Previously the arith folder test would emit `dense<255>` (`0xFF` zero extended). In-memory without bytecode is `0x01`, so this change ensures in-memory formats match. Also changes `0xFF` to `~0x00` since compilation on machines with signed chars was causing issues, this should ensure it is set to all ones regardless of char interpretation: ``` [1083/5044] Building CXX object tools/mlir/lib/IR/CMakeFiles/obj.MLIRIR.dir/BuiltinDialectBytecode.cpp.o /.../BuiltinDialectBytecode.cpp:184:35: warning: result of comparison of constant 255 with expression of type 'const char' is always false [-Wtautological-constant-out-of-range-compare] 184 | if (blob.size() == 1 && blob[0] == 0xFF) { | ~~~~~~~ ^ ~~~~ 1 warning generated. ``` Fixes llvm#186178
1 parent 8e105b3 commit da03148

2 files changed

Lines changed: 26 additions & 2 deletions

File tree

mlir/lib/IR/BuiltinDialectBytecode.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,18 @@ readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader, ShapedType type,
179179
// cheap.
180180
size_t numElements = type.getNumElements();
181181
size_t packedSize = llvm::divideCeil(numElements, 8);
182+
183+
// Unpack splats to single element 0x01 to match unpacked splat format.
184+
if (blob.size() == 1 && blob[0] == ~0x00) {
185+
rawData.resize(1);
186+
rawData[0] = 0x01;
187+
return success();
188+
}
189+
190+
// Unpack the blob if it's packed.
191+
// Splat and blob.size() == packedSize for all N<=8 elements are ambiguous,
192+
// non 0xFF means not splat so must be unpacked.
182193
if (blob.size() == packedSize && blob.size() != numElements) {
183-
// Unpack the blob.
184194
rawData.resize(numElements);
185195
for (size_t i = 0; i < numElements; ++i)
186196
rawData[i] = (blob[i / 8] & (1 << (i % 8))) ? 1 : 0;
@@ -200,9 +210,11 @@ static void writeDenseIntOrFPElementsAttr(DialectBytecodeWriter &writer,
200210
ArrayRef<char> rawData = attr.getRawData();
201211

202212
// If the attribute is a splat, we can just splat the value directly.
213+
// Use 0xFF to avoid ambiguity with packed format of <=8 elements,
214+
// written ~0x00 to ensure proper compilation with signed chars.
203215
if (attr.isSplat()) {
204216
data.resize(1);
205-
data[0] = rawData[0] ? 0xFF : 0x00;
217+
data[0] = rawData[0] ? ~0x00 : 0x00;
206218
writer.writeUnownedBlob(data);
207219
return;
208220
}

mlir/test/Bytecode/i1_roundtrip.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
// RUN: mlir-opt %s -emit-bytecode | mlir-opt | FileCheck %s
2+
// RUN: mlir-opt %s -canonicalize | FileCheck %s --check-prefix=CHECK-FOLD
3+
// RUN: mlir-opt %s -emit-bytecode | mlir-opt -canonicalize | FileCheck %s --check-prefix=CHECK-FOLD
24

35
// CHECK-LABEL: func.func @test_i1_splat_true
46
func.func @test_i1_splat_true() -> tensor<100xi1> {
@@ -43,3 +45,13 @@ func.func @test_i9_mixed() {
4345
%0 = arith.constant dense<[true, false, true, false, true, false, true, false, true]> : tensor<9xi1>
4446
return
4547
}
48+
49+
// Test that the in-memory representation of i1 values is correctly handled
50+
// during bytecode roundtrip (must be unpacked to 0x01 not 0xFF).
51+
// See llvm/llvm-project#186178.
52+
func.func public @test_in_memory_repr() -> (tensor<32xi32> {jax.result_info = "result"}) {
53+
// CHECK-FOLD: dense<1> : tensor<32xi32>
54+
%cst = arith.constant dense<true> : tensor<32xi1>
55+
%0 = arith.extui %cst : tensor<32xi1> to tensor<32xi32>
56+
return %0 : tensor<32xi32>
57+
}

0 commit comments

Comments
 (0)