11#include < utility>
22
3+ #include " lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
34#include " lib/Dialect/Secret/IR/SecretDialect.h"
45#include " lib/Dialect/Secret/IR/SecretOps.h"
56#include " lib/Dialect/Secret/IR/SecretTypes.h"
67#include " lib/Transforms/Secretize/Passes.h"
7- #include " llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
8- #include " llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
9- #include " mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
10- #include " mlir/include/mlir/IR/Block.h" // from @llvm-project
11- #include " mlir/include/mlir/IR/Builders.h" // from @llvm-project
12- #include " mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
13- #include " mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
14- #include " mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
15- #include " mlir/include/mlir/IR/Location.h" // from @llvm-project
16- #include " mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
17- #include " mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
18- #include " mlir/include/mlir/IR/Types.h" // from @llvm-project
19- #include " mlir/include/mlir/IR/Value.h" // from @llvm-project
20- #include " mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
21- #include " mlir/include/mlir/Support/LLVM.h" // from @llvm-project
22- #include " mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
8+ #include " llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
9+ #include " llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
10+ #include " mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project
11+ #include " mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
12+ #include " mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
13+ #include " mlir/include/mlir/IR/Block.h" // from @llvm-project
14+ #include " mlir/include/mlir/IR/Builders.h" // from @llvm-project
15+ #include " mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
16+ #include " mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
17+ #include " mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
18+ #include " mlir/include/mlir/IR/Location.h" // from @llvm-project
19+ #include " mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
20+ #include " mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
21+ #include " mlir/include/mlir/IR/Types.h" // from @llvm-project
22+ #include " mlir/include/mlir/IR/Value.h" // from @llvm-project
23+ #include " mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
24+ #include " mlir/include/mlir/Support/LLVM.h" // from @llvm-project
25+ #include " mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
2326#include " mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project
2427
2528namespace mlir {
@@ -29,8 +32,8 @@ namespace heir {
2932#include " lib/Transforms/Secretize/Passes.h.inc"
3033
3134struct WrapWithGeneric : public OpRewritePattern <func::FuncOp> {
32- WrapWithGeneric (mlir::MLIRContext* context)
33- : mlir::OpRewritePattern<func::FuncOp>(context) {}
35+ WrapWithGeneric (mlir::MLIRContext* context, DataFlowSolver* solver )
36+ : mlir::OpRewritePattern<func::FuncOp>(context), solver(solver) {}
3437
3538 LogicalResult matchAndRewrite (func::FuncOp op,
3639 PatternRewriter& rewriter) const override {
@@ -58,54 +61,200 @@ struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
5861 return rewriter.notifyMatchFailure (op, " no secret inputs found" );
5962 }
6063
61- auto newOutputs = llvm::to_vector<6 >(llvm::map_range (
62- op.getResultTypes (),
63- [](Type t) -> Type { return secret::SecretType::get (t); }));
64+ // Externally defined functions have no body - conservatively wrap all
65+ // outputs
66+ if (op.isDeclaration ()) {
67+ SmallVector<Type, 6 > newOutputs;
68+ for (Type resultType : op.getResultTypes ()) {
69+ newOutputs.push_back (secret::SecretType::get (resultType));
70+ }
71+ rewriter.modifyOpInPlace (op, [&] {
72+ op.setFunctionType (
73+ FunctionType::get (getContext (), {newInputs}, {newOutputs}));
74+ });
75+ return success ();
76+ }
77+
78+ // Phase 1: Identify which operations depend on secrets
79+ Block& opEntryBlock = op.getRegion ().front ();
80+ auto * returnOp = opEntryBlock.getTerminator ();
81+
82+ // Track which values are secret (including block arguments)
83+ llvm::DenseSet<Value> secretValues;
84+ for (unsigned i = 0 ; i < op.getNumArguments (); i++) {
85+ if (isSecret (op.getArgument (i), solver)) {
86+ secretValues.insert (op.getArgument (i));
87+ }
88+ }
89+
90+ // Track which operations are secret-dependent
91+ llvm::DenseSet<Operation*> secretOps;
92+ for (Operation& bodyOp : opEntryBlock) {
93+ if (&bodyOp == returnOp) continue ;
6494
65- // modification to function type should go through the rewriter
95+ // An operation is secret if any of its operands are secret
96+ bool isSecretOp = llvm::any_of (bodyOp.getOperands (), [&](Value operand) {
97+ return secretValues.contains (operand) || isSecret (operand, solver);
98+ });
99+
100+ if (isSecretOp) {
101+ secretOps.insert (&bodyOp);
102+ // All results of a secret op become secret
103+ for (Value result : bodyOp.getResults ()) {
104+ secretValues.insert (result);
105+ }
106+ }
107+ }
108+
109+ // Phase 2: Determine output types and which outputs need to be in generic
110+ SmallVector<Type, 6 > newOutputs;
111+ SmallVector<Value> secretReturnValues;
112+ SmallVector<Value> plaintextReturnValues;
113+ SmallVector<unsigned > secretReturnIndices;
114+ SmallVector<unsigned > plaintextReturnIndices;
115+
116+ for (auto [i, resultType] : llvm::enumerate (op.getResultTypes ())) {
117+ Value returnVal = returnOp->getOperand (i);
118+ if (secretValues.contains (returnVal) || isSecret (returnVal, solver)) {
119+ newOutputs.push_back (secret::SecretType::get (resultType));
120+ secretReturnValues.push_back (returnVal);
121+ secretReturnIndices.push_back (i);
122+ } else {
123+ newOutputs.push_back (resultType);
124+ plaintextReturnValues.push_back (returnVal);
125+ plaintextReturnIndices.push_back (i);
126+ }
127+ }
128+
129+ // Modification to function type should go through the rewriter
66130 rewriter.modifyOpInPlace (op, [&] {
67131 op.setFunctionType (
68132 FunctionType::get (getContext (), {newInputs}, {newOutputs}));
69133 });
70134
71- // Externally defined functions have no body
72- if (op.isDeclaration ()) {
135+ // If no operations are secret-dependent, we don't need a generic at all
136+ if (secretOps.empty ()) {
137+ // Just return the plaintexts directly
73138 return success ();
74139 }
75- // Create a new block where we will insert the new secret.generic and move
76- // the function ops into.
77- Block& opEntryBlock = op.getRegion ().front ();
140+
141+ // Phase 3: Collect inputs for the secret.generic block
142+ // These are: (1) secret arguments, (2) plaintext values used by secret ops
143+ SmallVector<Value> genericInputs;
144+ SmallVector<Type> genericInputTypes;
145+
146+ // Add all function arguments that are used by secret ops (or are secret)
147+ for (unsigned i = 0 ; i < op.getNumArguments (); i++) {
148+ genericInputs.push_back (op.getArgument (i));
149+ genericInputTypes.push_back (op.getArgument (i).getType ());
150+ }
151+
152+ // Collect plaintext-defined values that are used inside secret ops
153+ SmallVector<Value> plaintextValuesUsedInGeneric;
154+ for (Operation* secretOp : secretOps) {
155+ for (Value operand : secretOp->getOperands ()) {
156+ // If the operand is from outside the secretOps set (i.e., plaintext)
157+ if (!secretValues.contains (operand)) {
158+ Operation* defOp = operand.getDefiningOp ();
159+ // It's a plaintext value defined by a non-secret op in this function
160+ if (defOp && !secretOps.contains (defOp) &&
161+ defOp->getParentRegion () == &op.getRegion ()) {
162+ if (!llvm::is_contained (plaintextValuesUsedInGeneric, operand)) {
163+ plaintextValuesUsedInGeneric.push_back (operand);
164+ genericInputs.push_back (operand);
165+ genericInputTypes.push_back (operand.getType ());
166+ }
167+ }
168+ }
169+ }
170+ }
171+
172+ // Phase 4: Build the secret.generic with only secret ops
173+ SmallVector<Type> genericOutputTypes;
174+ for (Value v : secretReturnValues) {
175+ genericOutputTypes.push_back (secret::SecretType::get (v.getType ()));
176+ }
177+
178+ // Create a new block for the rewritten function
78179 auto * newBlock = rewriter.createBlock (
79180 &opEntryBlock, opEntryBlock.getArgumentTypes (),
80181 SmallVector<Location>(opEntryBlock.getNumArguments (), op.getLoc ()));
81182
82183 rewriter.setInsertionPointToStart (newBlock);
184+
185+ // Build mapping from old block args to new block args
186+ IRMapping outerMapping;
187+ for (unsigned i = 0 ; i < opEntryBlock.getNumArguments (); ++i) {
188+ outerMapping.map (opEntryBlock.getArgument (i), newBlock->getArgument (i));
189+ }
190+
191+ // Clone plaintext operations to the new block (before the generic)
192+ for (Operation& bodyOp : opEntryBlock) {
193+ if (&bodyOp == returnOp) continue ;
194+ if (!secretOps.contains (&bodyOp)) {
195+ Operation* clonedOp = rewriter.clone (bodyOp, outerMapping);
196+ for (unsigned i = 0 ; i < bodyOp.getNumResults (); ++i) {
197+ outerMapping.map (bodyOp.getResult (i), clonedOp->getResult (i));
198+ }
199+ }
200+ }
201+
202+ // Update genericInputs to use the new block's values
203+ SmallVector<Value> mappedGenericInputs;
204+ for (Value v : genericInputs) {
205+ mappedGenericInputs.push_back (outerMapping.lookupOrDefault (v));
206+ }
207+
208+ // Now create the secret.generic
83209 auto newGeneric = secret::GenericOp::create (
84- rewriter, op.getLoc (), op. getArguments (), newOutputs ,
210+ rewriter, op.getLoc (), mappedGenericInputs, genericOutputTypes ,
85211 [&](OpBuilder& b, Location loc, ValueRange blockArguments) {
86- // Map the input values to the block arguments.
87- IRMapping mp;
88- for (unsigned i = 0 ; i < blockArguments.size (); ++i) {
89- mp.map (opEntryBlock.getArgument (i), blockArguments[i]);
212+ // Map inputs to block arguments
213+ IRMapping innerMapping;
214+ for (unsigned i = 0 ; i < genericInputs.size (); ++i) {
215+ innerMapping.map (genericInputs[i], blockArguments[i]);
216+ }
217+
218+ // Clone only secret operations into the generic
219+ for (Operation& bodyOp : opEntryBlock) {
220+ if (&bodyOp == returnOp) continue ;
221+ if (secretOps.contains (&bodyOp)) {
222+ Operation* clonedOp = b.clone (bodyOp, innerMapping);
223+ for (unsigned i = 0 ; i < bodyOp.getNumResults (); ++i) {
224+ innerMapping.map (bodyOp.getResult (i), clonedOp->getResult (i));
225+ }
226+ }
90227 }
91228
92- auto * returnOp = opEntryBlock.getTerminator ();
93- secret::YieldOp::create (b, loc,
94- llvm::to_vector (llvm::map_range (
95- returnOp->getOperands (), [&](Value v) {
96- return mp.lookupOrDefault (v);
97- })));
98- returnOp->erase ();
229+ // Yield only the secret return values
230+ SmallVector<Value> yieldValues;
231+ for (Value v : secretReturnValues) {
232+ yieldValues.push_back (innerMapping.lookupOrDefault (v));
233+ }
234+ secret::YieldOp::create (b, loc, yieldValues);
99235 });
100236
101- Block& genericBlock = newGeneric.getRegion ().front ();
102- rewriter.inlineBlockBefore (&opEntryBlock,
103- &genericBlock.getOperations ().back (),
104- genericBlock.getArguments ());
105- func::ReturnOp::create (rewriter, op.getLoc (), newGeneric.getResults ());
237+ // Build the final return values in the correct order
238+ SmallVector<Value> finalReturnValues (op.getNumResults ());
239+ unsigned secretResultIdx = 0 ;
240+ for (unsigned idx : secretReturnIndices) {
241+ finalReturnValues[idx] = newGeneric.getResult (secretResultIdx++);
242+ }
243+ for (unsigned idx : plaintextReturnIndices) {
244+ Value returnVal = returnOp->getOperand (idx);
245+ finalReturnValues[idx] = outerMapping.lookupOrDefault (returnVal);
246+ }
247+
248+ func::ReturnOp::create (rewriter, op.getLoc (), finalReturnValues);
249+
250+ // Erase the old block
251+ rewriter.eraseBlock (&opEntryBlock);
106252
107253 return success ();
108254 }
255+
256+ private:
257+ DataFlowSolver* solver;
109258};
110259
111260struct ConvertFuncCall : public OpRewritePattern <func::CallOp> {
@@ -159,20 +308,28 @@ struct WrapGeneric : impl::WrapGenericBase<WrapGeneric> {
159308 using WrapGenericBase::WrapGenericBase;
160309
161310 void detectSecretGeneric () {
162- bool hasSecretGeneric = false ;
163- getOperation ().walk ([&](secret::GenericOp op) { hasSecretGeneric = true ; });
164- if (!hasSecretGeneric) {
165- getOperation ().emitWarning (
166- " No secret found in the module. Did you forget to annotate "
167- " {secret.secret} to the function arguments?" );
168- }
311+ // Note: Since we now correctly handle functions that return only
312+ // plaintext values (which don't get a secret.generic), we should not
313+ // warn about missing secret.generic ops. The warning was intended
314+ // for the case where users forgot to annotate secret inputs, but that
315+ // is already caught by the hasSecrets check in WrapWithGeneric.
169316 }
170317
171318 void runOnOperation () override {
172319 MLIRContext* context = &getContext ();
173320
321+ // Run SecretnessAnalysis to determine which values depend on secrets
322+ DataFlowSolver solver;
323+ dataflow::loadBaselineAnalyses (solver);
324+ solver.load <SecretnessAnalysis>();
325+ if (failed (solver.initializeAndRun (getOperation ()))) {
326+ getOperation ()->emitOpError () << " Failed to run SecretnessAnalysis.\n " ;
327+ signalPassFailure ();
328+ return ;
329+ }
330+
174331 mlir::RewritePatternSet patterns (context);
175- patterns.add <WrapWithGeneric>(context);
332+ patterns.add <WrapWithGeneric>(context, &solver );
176333 (void )walkAndApplyPatterns (getOperation (), std::move (patterns));
177334
178335 // func.call should be converted after callee func type updated
0 commit comments