Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/Transforms/Secretize/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ cc_library(
],
deps = [
":pass_inc_gen",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
Expand Down
107 changes: 75 additions & 32 deletions lib/Transforms/Secretize/WrapGeneric.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
#include <utility>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Secret/IR/SecretDialect.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Transforms/Secretize/Passes.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project

namespace mlir {
Expand All @@ -29,8 +31,8 @@ namespace heir {
#include "lib/Transforms/Secretize/Passes.h.inc"

struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
WrapWithGeneric(mlir::MLIRContext* context)
: mlir::OpRewritePattern<func::FuncOp>(context) {}
WrapWithGeneric(mlir::MLIRContext* context, DataFlowSolver* solver)
: mlir::OpRewritePattern<func::FuncOp>(context), solver(solver) {}

LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter& rewriter) const override {
Expand Down Expand Up @@ -58,23 +60,52 @@ struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
return rewriter.notifyMatchFailure(op, "no secret inputs found");
}

auto newOutputs = llvm::to_vector<6>(llvm::map_range(
op.getResultTypes(),
[](Type t) -> Type { return secret::SecretType::get(t); }));
// Externally defined functions have no body - conservatively wrap all
// outputs as secret
if (op.isDeclaration()) {
auto newOutputs = llvm::to_vector<6>(llvm::map_range(
op.getResultTypes(),
[](Type t) -> Type { return secret::SecretType::get(t); }));
rewriter.modifyOpInPlace(op, [&] {
op.setFunctionType(
FunctionType::get(getContext(), {newInputs}, {newOutputs}));
});
return success();
}

// Use SecretnessAnalysis to determine which outputs depend on secrets
Block& opEntryBlock = op.getRegion().front();
auto* returnOp = opEntryBlock.getTerminator();

// Determine output types: only wrap in secret if the value depends on
// secrets
SmallVector<Type, 6> newOutputs;
bool hasSecretOutputs = false;
for (auto [i, resultType] : llvm::enumerate(op.getResultTypes())) {
Value returnVal = returnOp->getOperand(i);
if (isSecret(returnVal, solver)) {
newOutputs.push_back(secret::SecretType::get(resultType));
hasSecretOutputs = true;
} else {
newOutputs.push_back(resultType);
}
}

// modification to function type should go through the rewriter
// Modification to function type should go through the rewriter
rewriter.modifyOpInPlace(op, [&] {
op.setFunctionType(
FunctionType::get(getContext(), {newInputs}, {newOutputs}));
});

// Externally defined functions have no body
if (op.isDeclaration()) {
// If no outputs depend on secrets, don't create a generic block.
// This fixes issue #2553: functions that return only plaintext values
// should not have their outputs wrapped in secret types.
if (!hasSecretOutputs) {
return success();
}

// Create a new block where we will insert the new secret.generic and move
// the function ops into.
Block& opEntryBlock = op.getRegion().front();
auto* newBlock = rewriter.createBlock(
&opEntryBlock, opEntryBlock.getArgumentTypes(),
SmallVector<Location>(opEntryBlock.getNumArguments(), op.getLoc()));
Expand All @@ -89,7 +120,7 @@ struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
mp.map(opEntryBlock.getArgument(i), blockArguments[i]);
}

auto* returnOp = opEntryBlock.getTerminator();
// Yield the return values, mapped through the IR mapping
secret::YieldOp::create(b, loc,
llvm::to_vector(llvm::map_range(
returnOp->getOperands(), [&](Value v) {
Expand All @@ -106,6 +137,9 @@ struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {

return success();
}

private:
DataFlowSolver* solver;
};

struct ConvertFuncCall : public OpRewritePattern<func::CallOp> {
Expand Down Expand Up @@ -161,18 +195,27 @@ struct WrapGeneric : impl::WrapGenericBase<WrapGeneric> {
void detectSecretGeneric() {
bool hasSecretGeneric = false;
getOperation().walk([&](secret::GenericOp op) { hasSecretGeneric = true; });
if (!hasSecretGeneric) {
getOperation().emitWarning(
"No secret found in the module. Did you forget to annotate "
"{secret.secret} to the function arguments?");
}
// Note: We no longer warn if no secret.generic is found, because
// functions that return only plaintext values intentionally don't
// create a secret.generic block. The hasSecrets check in WrapWithGeneric
// already catches the case where users forget to annotate secret inputs.
}

void runOnOperation() override {
MLIRContext* context = &getContext();

// Run SecretnessAnalysis to determine which values depend on secrets
DataFlowSolver solver;
dataflow::loadBaselineAnalyses(solver);
solver.load<SecretnessAnalysis>();
if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run SecretnessAnalysis.\n";
signalPassFailure();
return;
}

mlir::RewritePatternSet patterns(context);
patterns.add<WrapWithGeneric>(context);
patterns.add<WrapWithGeneric>(context, &solver);
(void)walkAndApplyPatterns(getOperation(), std::move(patterns));

// func.call should be converted after callee func type updated
Expand Down
17 changes: 17 additions & 0 deletions tests/Dialect/Secret/Transforms/wrap_generic/wrap_generic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,20 @@ module {
return %alloc : memref<1x80xi8>
}
}

// -----

// Regression test for issue #2553: plaintext constant should not become secret
// When a function only returns values that don't depend on secrets,
// no secret.generic should be created.
module {
// CHECK: @plaintext_output(%arg0: !secret.secret<i32>) -> i8
func.func @plaintext_output(%x: i32 {secret.secret}) -> i8 {
// The constant does not depend on the secret input
// CHECK-NOT: secret.generic
%0 = arith.constant 42 : i8
// CHECK: return %{{.*}} : i8
func.return %0 : i8
}
}