Skip to content

Commit b29bf9f

Browse files
authored
[mlir][bytecode] Add option to elide locations during serialization (#201183)
Adds a setElideLocations option to BytecodeWriterConfig to elide locations during bytecode serialization. When enabled, all LocationAttrs are mapped to UnknownLoc during numbering and writing to produce location-invariant bytecode (e.g., for stable fingerprinting). Another way to achieve the same thing would be to apply the strip-debuginfo pass, but that requires mutating the module, which in turn requires cloning the module if one still requires the unstripped original. Assisted-by: Antigravity / Gemini
1 parent 70edfe9 commit b29bf9f

5 files changed

Lines changed: 145 additions & 0 deletions

File tree

mlir/include/mlir/Bytecode/BytecodeWriter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,12 @@ class BytecodeWriterConfig {
177177
attachResourcePrinter(std::move(printer));
178178
}
179179

180+
/// Set a boolean flag to skip emission of unique locations into the bytecode
181+
/// file. When enabled, all locations are mapped to UnknownLoc during
182+
/// numbering.
183+
void setElideLocations(bool shouldElideLocations = true);
184+
bool shouldElideLocations() const;
185+
180186
private:
181187
/// A pointer to allocated storage for the impl state.
182188
std::unique_ptr<Impl> impl;

mlir/lib/Bytecode/Writer/BytecodeWriter.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ struct BytecodeWriterConfig::Impl {
4444
/// file.
4545
bool shouldElideResourceData = false;
4646

47+
/// A flag specifying whether to elide emission of locations.
48+
bool shouldElideLocations = false;
49+
4750
/// A map containing dialect version information for each dialect to emit.
4851
llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap;
4952

@@ -102,6 +105,14 @@ void BytecodeWriterConfig::setElideResourceDataFlag(
102105
impl->shouldElideResourceData = shouldElideResourceData;
103106
}
104107

108+
void BytecodeWriterConfig::setElideLocations(bool shouldElideLocations) {
109+
impl->shouldElideLocations = shouldElideLocations;
110+
}
111+
112+
bool BytecodeWriterConfig::shouldElideLocations() const {
113+
return impl->shouldElideLocations;
114+
}
115+
105116
void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) {
106117
impl->bytecodeVersion = bytecodeVersion;
107118
}

mlir/lib/Bytecode/Writer/IRNumbering.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Bytecode/Encoding.h"
1414
#include "mlir/IR/AsmState.h"
1515
#include "mlir/IR/BuiltinTypes.h"
16+
#include "mlir/IR/Location.h"
1617
#include "mlir/IR/OpDefinition.h"
1718

1819
using namespace mlir;
@@ -198,6 +199,13 @@ IRNumberingState::IRNumberingState(Operation *op,
198199
finalizeDialectResourceNumberings(op);
199200
}
200201

202+
unsigned IRNumberingState::getNumber(Location loc) {
203+
if (config.shouldElideLocations()) {
204+
return getNumber(Attribute(UnknownLoc::get(loc.getContext())));
205+
}
206+
return getNumber(Attribute(loc));
207+
}
208+
201209
void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
202210
// A simple state struct tracking data used when walking operations.
203211
struct StackState {
@@ -308,6 +316,14 @@ void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
308316
});
309317
}
310318

319+
void IRNumberingState::number(Location loc) {
320+
if (config.shouldElideLocations()) {
321+
number(Attribute(UnknownLoc::get(loc.getContext())));
322+
} else {
323+
number(Attribute(loc));
324+
}
325+
}
326+
311327
void IRNumberingState::number(Attribute attr) {
312328
auto it = attrs.try_emplace(attr);
313329
if (!it.second) {

mlir/lib/Bytecode/Writer/IRNumbering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
1515
#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
1616

17+
#include "mlir/IR/Location.h"
1718
#include "mlir/IR/OpImplementation.h"
1819
#include "llvm/ADT/MapVector.h"
1920
#include "llvm/ADT/SetVector.h"
@@ -165,6 +166,7 @@ class IRNumberingState {
165166
assert(attrs.count(attr) && "attribute not numbered");
166167
return attrs[attr]->number;
167168
}
169+
unsigned getNumber(Location loc);
168170
unsigned getNumber(Block *block) {
169171
assert(blockIDs.count(block) && "block not numbered");
170172
return blockIDs[block];
@@ -221,6 +223,7 @@ class IRNumberingState {
221223

222224
/// Number the given IR unit for bytecode emission.
223225
void number(Attribute attr);
226+
void number(Location loc);
224227
void number(Block &block);
225228
DialectNumbering &numberDialect(Dialect *dialect);
226229
DialectNumbering &numberDialect(StringRef dialect);

mlir/unittests/Bytecode/BytecodeTest.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,112 @@ TEST(Bytecode, EmptyFusedLocRoundtrip) {
292292

293293
module.erase();
294294
}
295+
296+
TEST(Bytecode, LocationElision) {
297+
MLIRContext context;
298+
context.allowUnregisteredDialects();
299+
ParserConfig config(&context);
300+
301+
// Module 1: Reuses the same location "foo" everywhere.
302+
StringRef ir1 = R"mlir(
303+
module @Test {
304+
"test.op"() : () -> () loc("foo")
305+
} loc("foo")
306+
)mlir";
307+
308+
// Module 2: Uses unique locations everywhere.
309+
StringRef ir2 = R"mlir(
310+
module @Test {
311+
"test.op"() : () -> () loc("a")
312+
} loc("b")
313+
)mlir";
314+
315+
OwningOpRef<Operation *> op1 = parseSourceString(ir1, config);
316+
OwningOpRef<Operation *> op2 = parseSourceString(ir2, config);
317+
ASSERT_TRUE(op1);
318+
ASSERT_TRUE(op2);
319+
320+
// Serialize both with location elision enabled.
321+
BytecodeWriterConfig writerConfig;
322+
writerConfig.setElideLocations(true);
323+
324+
std::string bytecode1;
325+
{
326+
llvm::raw_string_ostream os(bytecode1);
327+
ASSERT_TRUE(succeeded(writeBytecodeToFile(op1.get(), os, writerConfig)));
328+
}
329+
330+
std::string bytecode2;
331+
{
332+
llvm::raw_string_ostream os(bytecode2);
333+
ASSERT_TRUE(succeeded(writeBytecodeToFile(op2.get(), os, writerConfig)));
334+
}
335+
336+
// If location elision is working correctly, both modules must produce
337+
// the EXACT same bytecode representation, because all locations (shared or
338+
// unique) will have been collapsed into a single shared UnknownLoc.
339+
EXPECT_EQ(bytecode1, bytecode2);
340+
}
341+
342+
TEST(Bytecode, LocationElisionPreservesAttributes) {
343+
MLIRContext context;
344+
context.allowUnregisteredDialects();
345+
ParserConfig config(&context);
346+
347+
// An operation with a debug location ("elide_me") AND a semantic attribute
348+
// that is a LocationAttr ("preserve_me").
349+
StringRef ir = R"mlir(
350+
module @Test {
351+
"test.op"() {some_loc_attr = loc("preserve_me")} : () -> () loc("elide_me")
352+
} loc("elide_me")
353+
)mlir";
354+
355+
OwningOpRef<Operation *> op = parseSourceString(ir, config);
356+
ASSERT_TRUE(op);
357+
358+
// Serialize with location elision enabled.
359+
BytecodeWriterConfig writerConfig;
360+
writerConfig.setElideLocations(true);
361+
362+
std::string bytecode;
363+
{
364+
llvm::raw_string_ostream os(bytecode);
365+
ASSERT_TRUE(succeeded(writeBytecodeToFile(op.get(), os, writerConfig)));
366+
}
367+
368+
// Parse it back using the bytecode reader.
369+
std::unique_ptr<Block> block = std::make_unique<Block>();
370+
ASSERT_TRUE(succeeded(readBytecodeFile(
371+
llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config)));
372+
373+
// Verify we got the roundtripped module.
374+
ASSERT_FALSE(block->empty());
375+
Operation *roundTrippedModule = &block->front();
376+
ASSERT_TRUE(roundTrippedModule);
377+
378+
// Find the inner "test.op" operation.
379+
Operation *innerOp = nullptr;
380+
roundTrippedModule->walk([&](Operation *op) {
381+
if (op->getName().getStringRef() == "test.op") {
382+
innerOp = op;
383+
}
384+
});
385+
ASSERT_TRUE(innerOp);
386+
387+
// 1. Verify that the debug location of "test.op" WAS elided (became
388+
// UnknownLoc).
389+
EXPECT_TRUE(isa<UnknownLoc>(innerOp->getLoc()));
390+
391+
// 2. Verify that the semantic location attribute WAS PRESERVED.
392+
Attribute semanticLocAttr = innerOp->getAttr("some_loc_attr");
393+
ASSERT_TRUE(semanticLocAttr);
394+
auto locAttr = dyn_cast<LocationAttr>(semanticLocAttr);
395+
ASSERT_TRUE(locAttr);
396+
397+
// It should still be loc("preserve_me"), not UnknownLoc.
398+
EXPECT_FALSE(isa<UnknownLoc>(locAttr));
399+
400+
auto nameLoc = dyn_cast<NameLoc>(locAttr);
401+
ASSERT_TRUE(nameLoc);
402+
EXPECT_EQ(nameLoc.getName().getValue(), "preserve_me");
403+
}

0 commit comments

Comments
 (0)