Skip to content

Commit 7492928

Browse files
committed
fix(Secretize): selective output wrapping in WrapGeneric using SecretnessAnalysis
1 parent bdd554d commit 7492928

3 files changed

Lines changed: 228 additions & 52 deletions

File tree

lib/Transforms/Secretize/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ cc_library(
1717
],
1818
deps = [
1919
":pass_inc_gen",
20+
"@heir//lib/Analysis/SecretnessAnalysis",
2021
"@heir//lib/Dialect/Secret/IR:Dialect",
2122
"@llvm-project//llvm:Support",
23+
"@llvm-project//mlir:Analysis",
2224
"@llvm-project//mlir:FuncDialect",
2325
"@llvm-project//mlir:IR",
2426
"@llvm-project//mlir:Pass",

lib/Transforms/Secretize/WrapGeneric.cpp

Lines changed: 209 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
11
#include <utility>
22

3+
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
34
#include "lib/Dialect/Secret/IR/SecretDialect.h"
45
#include "lib/Dialect/Secret/IR/SecretOps.h"
56
#include "lib/Dialect/Secret/IR/SecretTypes.h"
67
#include "lib/Transforms/Secretize/Passes.h"
7-
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
8-
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
9-
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
10-
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
11-
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
12-
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
13-
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
14-
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
15-
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
16-
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
17-
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
18-
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
19-
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
20-
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
21-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
22-
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
8+
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
9+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
10+
#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project
11+
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
12+
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
13+
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
14+
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
15+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
16+
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
17+
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
18+
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
19+
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
20+
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
21+
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
22+
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
23+
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
24+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
25+
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
2326
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project
2427

2528
namespace mlir {
@@ -29,8 +32,8 @@ namespace heir {
2932
#include "lib/Transforms/Secretize/Passes.h.inc"
3033

3134
struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
32-
WrapWithGeneric(mlir::MLIRContext* context)
33-
: mlir::OpRewritePattern<func::FuncOp>(context) {}
35+
WrapWithGeneric(mlir::MLIRContext* context, DataFlowSolver* solver)
36+
: mlir::OpRewritePattern<func::FuncOp>(context), solver(solver) {}
3437

3538
LogicalResult matchAndRewrite(func::FuncOp op,
3639
PatternRewriter& rewriter) const override {
@@ -58,54 +61,200 @@ struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
5861
return rewriter.notifyMatchFailure(op, "no secret inputs found");
5962
}
6063

61-
auto newOutputs = llvm::to_vector<6>(llvm::map_range(
62-
op.getResultTypes(),
63-
[](Type t) -> Type { return secret::SecretType::get(t); }));
64+
// Externally defined functions have no body - conservatively wrap all
65+
// outputs
66+
if (op.isDeclaration()) {
67+
SmallVector<Type, 6> newOutputs;
68+
for (Type resultType : op.getResultTypes()) {
69+
newOutputs.push_back(secret::SecretType::get(resultType));
70+
}
71+
rewriter.modifyOpInPlace(op, [&] {
72+
op.setFunctionType(
73+
FunctionType::get(getContext(), {newInputs}, {newOutputs}));
74+
});
75+
return success();
76+
}
77+
78+
// Phase 1: Identify which operations depend on secrets
79+
Block& opEntryBlock = op.getRegion().front();
80+
auto* returnOp = opEntryBlock.getTerminator();
81+
82+
// Track which values are secret (including block arguments)
83+
llvm::DenseSet<Value> secretValues;
84+
for (unsigned i = 0; i < op.getNumArguments(); i++) {
85+
if (isSecret(op.getArgument(i), solver)) {
86+
secretValues.insert(op.getArgument(i));
87+
}
88+
}
89+
90+
// Track which operations are secret-dependent
91+
llvm::DenseSet<Operation*> secretOps;
92+
for (Operation& bodyOp : opEntryBlock) {
93+
if (&bodyOp == returnOp) continue;
6494

65-
// modification to function type should go through the rewriter
95+
// An operation is secret if any of its operands are secret
96+
bool isSecretOp = llvm::any_of(bodyOp.getOperands(), [&](Value operand) {
97+
return secretValues.contains(operand) || isSecret(operand, solver);
98+
});
99+
100+
if (isSecretOp) {
101+
secretOps.insert(&bodyOp);
102+
// All results of a secret op become secret
103+
for (Value result : bodyOp.getResults()) {
104+
secretValues.insert(result);
105+
}
106+
}
107+
}
108+
109+
// Phase 2: Determine output types and which outputs need to be in generic
110+
SmallVector<Type, 6> newOutputs;
111+
SmallVector<Value> secretReturnValues;
112+
SmallVector<Value> plaintextReturnValues;
113+
SmallVector<unsigned> secretReturnIndices;
114+
SmallVector<unsigned> plaintextReturnIndices;
115+
116+
for (auto [i, resultType] : llvm::enumerate(op.getResultTypes())) {
117+
Value returnVal = returnOp->getOperand(i);
118+
if (secretValues.contains(returnVal) || isSecret(returnVal, solver)) {
119+
newOutputs.push_back(secret::SecretType::get(resultType));
120+
secretReturnValues.push_back(returnVal);
121+
secretReturnIndices.push_back(i);
122+
} else {
123+
newOutputs.push_back(resultType);
124+
plaintextReturnValues.push_back(returnVal);
125+
plaintextReturnIndices.push_back(i);
126+
}
127+
}
128+
129+
// Modification to function type should go through the rewriter
66130
rewriter.modifyOpInPlace(op, [&] {
67131
op.setFunctionType(
68132
FunctionType::get(getContext(), {newInputs}, {newOutputs}));
69133
});
70134

71-
// Externally defined functions have no body
72-
if (op.isDeclaration()) {
135+
// If no operations are secret-dependent, we don't need a generic at all
136+
if (secretOps.empty()) {
137+
// Just return the plaintexts directly
73138
return success();
74139
}
75-
// Create a new block where we will insert the new secret.generic and move
76-
// the function ops into.
77-
Block& opEntryBlock = op.getRegion().front();
140+
141+
// Phase 3: Collect inputs for the secret.generic block
142+
// These are: (1) secret arguments, (2) plaintext values used by secret ops
143+
SmallVector<Value> genericInputs;
144+
SmallVector<Type> genericInputTypes;
145+
146+
// Add all function arguments that are used by secret ops (or are secret)
147+
for (unsigned i = 0; i < op.getNumArguments(); i++) {
148+
genericInputs.push_back(op.getArgument(i));
149+
genericInputTypes.push_back(op.getArgument(i).getType());
150+
}
151+
152+
// Collect plaintext-defined values that are used inside secret ops
153+
SmallVector<Value> plaintextValuesUsedInGeneric;
154+
for (Operation* secretOp : secretOps) {
155+
for (Value operand : secretOp->getOperands()) {
156+
// If the operand is from outside the secretOps set (i.e., plaintext)
157+
if (!secretValues.contains(operand)) {
158+
Operation* defOp = operand.getDefiningOp();
159+
// It's a plaintext value defined by a non-secret op in this function
160+
if (defOp && !secretOps.contains(defOp) &&
161+
defOp->getParentRegion() == &op.getRegion()) {
162+
if (!llvm::is_contained(plaintextValuesUsedInGeneric, operand)) {
163+
plaintextValuesUsedInGeneric.push_back(operand);
164+
genericInputs.push_back(operand);
165+
genericInputTypes.push_back(operand.getType());
166+
}
167+
}
168+
}
169+
}
170+
}
171+
172+
// Phase 4: Build the secret.generic with only secret ops
173+
SmallVector<Type> genericOutputTypes;
174+
for (Value v : secretReturnValues) {
175+
genericOutputTypes.push_back(secret::SecretType::get(v.getType()));
176+
}
177+
178+
// Create a new block for the rewritten function
78179
auto* newBlock = rewriter.createBlock(
79180
&opEntryBlock, opEntryBlock.getArgumentTypes(),
80181
SmallVector<Location>(opEntryBlock.getNumArguments(), op.getLoc()));
81182

82183
rewriter.setInsertionPointToStart(newBlock);
184+
185+
// Build mapping from old block args to new block args
186+
IRMapping outerMapping;
187+
for (unsigned i = 0; i < opEntryBlock.getNumArguments(); ++i) {
188+
outerMapping.map(opEntryBlock.getArgument(i), newBlock->getArgument(i));
189+
}
190+
191+
// Clone plaintext operations to the new block (before the generic)
192+
for (Operation& bodyOp : opEntryBlock) {
193+
if (&bodyOp == returnOp) continue;
194+
if (!secretOps.contains(&bodyOp)) {
195+
Operation* clonedOp = rewriter.clone(bodyOp, outerMapping);
196+
for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) {
197+
outerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i));
198+
}
199+
}
200+
}
201+
202+
// Update genericInputs to use the new block's values
203+
SmallVector<Value> mappedGenericInputs;
204+
for (Value v : genericInputs) {
205+
mappedGenericInputs.push_back(outerMapping.lookupOrDefault(v));
206+
}
207+
208+
// Now create the secret.generic
83209
auto newGeneric = secret::GenericOp::create(
84-
rewriter, op.getLoc(), op.getArguments(), newOutputs,
210+
rewriter, op.getLoc(), mappedGenericInputs, genericOutputTypes,
85211
[&](OpBuilder& b, Location loc, ValueRange blockArguments) {
86-
// Map the input values to the block arguments.
87-
IRMapping mp;
88-
for (unsigned i = 0; i < blockArguments.size(); ++i) {
89-
mp.map(opEntryBlock.getArgument(i), blockArguments[i]);
212+
// Map inputs to block arguments
213+
IRMapping innerMapping;
214+
for (unsigned i = 0; i < genericInputs.size(); ++i) {
215+
innerMapping.map(genericInputs[i], blockArguments[i]);
216+
}
217+
218+
// Clone only secret operations into the generic
219+
for (Operation& bodyOp : opEntryBlock) {
220+
if (&bodyOp == returnOp) continue;
221+
if (secretOps.contains(&bodyOp)) {
222+
Operation* clonedOp = b.clone(bodyOp, innerMapping);
223+
for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) {
224+
innerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i));
225+
}
226+
}
90227
}
91228

92-
auto* returnOp = opEntryBlock.getTerminator();
93-
secret::YieldOp::create(b, loc,
94-
llvm::to_vector(llvm::map_range(
95-
returnOp->getOperands(), [&](Value v) {
96-
return mp.lookupOrDefault(v);
97-
})));
98-
returnOp->erase();
229+
// Yield only the secret return values
230+
SmallVector<Value> yieldValues;
231+
for (Value v : secretReturnValues) {
232+
yieldValues.push_back(innerMapping.lookupOrDefault(v));
233+
}
234+
secret::YieldOp::create(b, loc, yieldValues);
99235
});
100236

101-
Block& genericBlock = newGeneric.getRegion().front();
102-
rewriter.inlineBlockBefore(&opEntryBlock,
103-
&genericBlock.getOperations().back(),
104-
genericBlock.getArguments());
105-
func::ReturnOp::create(rewriter, op.getLoc(), newGeneric.getResults());
237+
// Build the final return values in the correct order
238+
SmallVector<Value> finalReturnValues(op.getNumResults());
239+
unsigned secretResultIdx = 0;
240+
for (unsigned idx : secretReturnIndices) {
241+
finalReturnValues[idx] = newGeneric.getResult(secretResultIdx++);
242+
}
243+
for (unsigned idx : plaintextReturnIndices) {
244+
Value returnVal = returnOp->getOperand(idx);
245+
finalReturnValues[idx] = outerMapping.lookupOrDefault(returnVal);
246+
}
247+
248+
func::ReturnOp::create(rewriter, op.getLoc(), finalReturnValues);
249+
250+
// Erase the old block
251+
rewriter.eraseBlock(&opEntryBlock);
106252

107253
return success();
108254
}
255+
256+
private:
257+
DataFlowSolver* solver;
109258
};
110259

111260
struct ConvertFuncCall : public OpRewritePattern<func::CallOp> {
@@ -159,20 +308,28 @@ struct WrapGeneric : impl::WrapGenericBase<WrapGeneric> {
159308
using WrapGenericBase::WrapGenericBase;
160309

161310
void detectSecretGeneric() {
162-
bool hasSecretGeneric = false;
163-
getOperation().walk([&](secret::GenericOp op) { hasSecretGeneric = true; });
164-
if (!hasSecretGeneric) {
165-
getOperation().emitWarning(
166-
"No secret found in the module. Did you forget to annotate "
167-
"{secret.secret} to the function arguments?");
168-
}
311+
// Note: Since we now correctly handle functions that return only
312+
// plaintext values (which don't get a secret.generic), we should not
313+
// warn about missing secret.generic ops. The warning was intended
314+
// for the case where users forgot to annotate secret inputs, but that
315+
// is already caught by the hasSecrets check in WrapWithGeneric.
169316
}
170317

171318
void runOnOperation() override {
172319
MLIRContext* context = &getContext();
173320

321+
// Run SecretnessAnalysis to determine which values depend on secrets
322+
DataFlowSolver solver;
323+
dataflow::loadBaselineAnalyses(solver);
324+
solver.load<SecretnessAnalysis>();
325+
if (failed(solver.initializeAndRun(getOperation()))) {
326+
getOperation()->emitOpError() << "Failed to run SecretnessAnalysis.\n";
327+
signalPassFailure();
328+
return;
329+
}
330+
174331
mlir::RewritePatternSet patterns(context);
175-
patterns.add<WrapWithGeneric>(context);
332+
patterns.add<WrapWithGeneric>(context, &solver);
176333
(void)walkAndApplyPatterns(getOperation(), std::move(patterns));
177334

178335
// func.call should be converted after callee func type updated

tests/Dialect/Secret/Transforms/wrap_generic/wrap_generic.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,20 @@ module {
8484
return %alloc : memref<1x80xi8>
8585
}
8686
}
87+
88+
// -----
89+
90+
// Regression test for issue #2553: plaintext constant should not become secret
91+
// When a function only returns values that don't depend on secrets,
92+
// no secret.generic should be created.
93+
module {
94+
// CHECK: @plaintext_output(%arg0: !secret.secret<i32>) -> i8
95+
func.func @plaintext_output(%x: i32 {secret.secret}) -> i8 {
96+
// The constant does not depend on the secret input
97+
// CHECK-NOT: secret.generic
98+
%0 = arith.constant 42 : i8
99+
// CHECK: return %{{.*}} : i8
100+
func.return %0 : i8
101+
}
102+
}
103+

0 commit comments

Comments
 (0)