Skip to content

Commit 1d4155f

Browse files
Merge pull request #2653 from VedantParanjape:permutation
PiperOrigin-RevId: 871950476
2 parents d4a089d + a1df9b9 commit 1d4155f

7 files changed

Lines changed: 142 additions & 1 deletion

File tree

lib/Dialect/TensorExt/IR/TensorExtOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,24 @@ LogicalResult verifyLayoutMatchesType(const Attribute& layoutAttr, Type type,
129129
return success();
130130
}
131131

132+
if (auto denseElementsAttr = dyn_cast<DenseIntElementsAttr>(layoutAttr)) {
133+
// Assert the attr has shape <N x 4>
134+
int64_t rank = denseElementsAttr.getType().getRank();
135+
if (rank != 2)
136+
return op->emitOpError()
137+
<< "requires permutation attribute to Rank 2, but "
138+
<< "found shape <" << denseElementsAttr.getType() << ">";
139+
140+
int64_t cols = denseElementsAttr.getType().getDimSize(1);
141+
if (cols != 4)
142+
return op->emitOpError()
143+
<< "requires permutation attribute to be of shape <N x 4>, but "
144+
"found shape <"
145+
<< denseElementsAttr.getType() << ">" << "Rank: " << rank
146+
<< " Cols: " << cols << "\n";
147+
return success();
148+
}
149+
132150
return op->emitOpError("Unsupported layout attribute");
133151
}
134152

lib/Dialect/TensorExt/IR/TensorExtOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def TensorExt_RemapOp : TensorExt_Op<"remap", [Pure, AllTypesMatch<["input", "ou
9292
// Forces ops to use a general Attribute and dyn_cast to the specific kind of
9393
// layout they support.
9494
def LayoutLike : AnyAttrOf<[
95+
// A list of tuples [a, b, c, d] representing an explicit map (ct, slot) ->
96+
// (ct, slot) defined by f(a, b) = (c, d).
97+
AnyI64ElementsAttr,
9598
TensorExt_LayoutAttr,
9699
]>;
97100

lib/Transforms/ConvertToCiphertextSemantics/AssignLayout.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,51 @@ static FailureOr<Value> implementUnpackOpNew(
195195
return loop.value().getResults()[0];
196196
}
197197

198+
static FailureOr<Value> implementAssignLayoutPermutation(
199+
Value input, DenseIntElementsAttr permutation, int64_t ciphertextSize,
200+
ImplicitLocOpBuilder& builder,
201+
const std::function<void(Operation*)>& createdOpCallback) {
202+
auto elementType = getElementTypeOrSelf(input.getType());
203+
204+
// TODO(#2666): This logic assumes ctxt = 0, this can be extended to handle a
205+
// more general case. permutation is <N x 4 x i64>: each row is [src_ct,
206+
// src_slot, dst_ct, dst_slot]. src_ct and dst_ct are always 0 (single
207+
// ciphertext), so treat as 1D: extract from input[0][src_slot] and insert
208+
// into result[0][dst_slot].
209+
auto ciphertextType = RankedTensorType::get({1, ciphertextSize}, elementType);
210+
auto zeroCtxt = arith::ConstantOp::create(
211+
builder, ciphertextType, builder.getZeroAttr(ciphertextType));
212+
createdOpCallback(zeroCtxt);
213+
Value result = zeroCtxt.getResult();
214+
auto ctIdxOp = arith::ConstantIndexOp::create(builder, 0);
215+
createdOpCallback(ctIdxOp);
216+
Value ctIdx = ctIdxOp.getResult();
217+
218+
for (auto it = permutation.value_begin<APInt>();
219+
it != permutation.value_end<APInt>();) {
220+
// skip src_ct (always 0)
221+
++it;
222+
int64_t srcSlot = (*it++).getSExtValue();
223+
// skip dst_ct (always 0)
224+
++it;
225+
int64_t dstSlot = (*it++).getSExtValue();
226+
227+
auto srcSlotIdx = arith::ConstantIndexOp::create(builder, srcSlot);
228+
createdOpCallback(srcSlotIdx);
229+
auto extracted = tensor::ExtractOp::create(builder, input,
230+
ValueRange{ctIdx, srcSlotIdx});
231+
createdOpCallback(extracted);
232+
auto dstSlotIdx = arith::ConstantIndexOp::create(builder, dstSlot);
233+
createdOpCallback(dstSlotIdx);
234+
auto insertOp = tensor::InsertOp::create(
235+
builder, extracted.getResult(), result, ValueRange{ctIdx, dstSlotIdx});
236+
createdOpCallback(insertOp);
237+
result = insertOp.getResult();
238+
}
239+
240+
return result;
241+
}
242+
198243
FailureOr<Value> implementAssignLayout(
199244
Value input, Attribute layout, int64_t ciphertextSize,
200245
ImplicitLocOpBuilder& builder,
@@ -203,6 +248,10 @@ FailureOr<Value> implementAssignLayout(
203248
if (LayoutAttr layoutAttr = dyn_cast<LayoutAttr>(layout)) {
204249
return implementAssignLayoutNew(input, layoutAttr, ciphertextSize, builder,
205250
createdOpCallback);
251+
} else if (DenseIntElementsAttr elementAttr =
252+
dyn_cast<DenseIntElementsAttr>(layout)) {
253+
return implementAssignLayoutPermutation(input, elementAttr, ciphertextSize,
254+
builder, createdOpCallback);
206255
}
207256
return builder.emitError() << "Unsupported layout attribute type: " << layout;
208257
};

lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,17 @@ struct LayoutMaterializationTypeConverter
206206
[this](IndexType type, LayoutAttr attr) -> std::optional<Type> {
207207
return materializeLayout(type, attr, getCiphertextSize());
208208
});
209+
addConversion([this](secret::SecretType type,
210+
DenseIntElementsAttr attr) -> std::optional<Type> {
211+
return secret::SecretType::get(materializePermutationLayout(
212+
getElementTypeOrSelf(type.getValueType()), attr,
213+
getCiphertextSize()));
214+
});
215+
addConversion([this](RankedTensorType type,
216+
DenseIntElementsAttr attr) -> std::optional<Type> {
217+
return materializePermutationLayout(getElementTypeOrSelf(type), attr,
218+
getCiphertextSize());
219+
});
209220
}
210221

211222
int getCiphertextSize() const { return ciphertextSize; }
@@ -352,7 +363,7 @@ class ConvertAssignLayout
352363

353364
// Check cache for existing assign layout function.
354365
Attribute layout = op.getLayout();
355-
if (!isa<LayoutAttr>(layout)) {
366+
if (!isa<LayoutAttr>(layout) && !isa<DenseIntElementsAttr>(layout)) {
356367
return failure();
357368
}
358369
Type inputType = input.getType();

lib/Transforms/ConvertToCiphertextSemantics/TypeConversion.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,13 @@ Type materializeScalarLayout(Type type, LayoutAttr attr, int ciphertextSize) {
4141
return RankedTensorType::get({1, ciphertextSize}, type);
4242
}
4343

44+
Type materializePermutationLayout(Type elementType,
45+
DenseIntElementsAttr permutation,
46+
int ciphertextSize) {
47+
// TODO(#2666): Extend to a more general case where src_ct and dst_ct != 0
48+
// src_ct and dst_ct are always 0; output is always a single ciphertext.
49+
return RankedTensorType::get({1, ciphertextSize}, elementType);
50+
}
51+
4452
} // namespace heir
4553
} // namespace mlir

lib/Transforms/ConvertToCiphertextSemantics/TypeConversion.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ Type materializeLayout(Type dataType, tensor_ext::LayoutAttr attr,
1818
Type materializeScalarLayout(Type type, tensor_ext::LayoutAttr attr,
1919
int ciphertextSize);
2020

21+
// Computes the ciphertext-semantic type for a permutation layout given as a
22+
// <N x 4 x i64> DenseIntElementsAttr of (src_ct, src_slot, dst_ct, dst_slot)
23+
// tuples. The number of ciphertexts is derived from the max dst_ct value.
24+
Type materializePermutationLayout(Type elementType,
25+
DenseIntElementsAttr permutation,
26+
int ciphertextSize);
27+
2128
} // namespace heir
2229
} // namespace mlir
2330

tests/Transforms/convert_to_ciphertext_semantics/assign_layout.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,48 @@ module {
6464
return %0 : !secret.secret<tensor<32xi16>>
6565
}
6666
}
67+
68+
// -----
69+
70+
// CHECK: func.func private @_assign_layout_{{[0-9]+}}
71+
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xi16>) -> tensor<1x32xi16>
72+
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
73+
// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
74+
// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
75+
// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
76+
// CHECK-DAG: %[[c6:.*]] = arith.constant 6 : index
77+
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
78+
// CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
79+
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0> : tensor<1x32xi16>
80+
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
81+
// CHECK-DAG: %[[EXT0:.*]] = tensor.extract %arg0[%[[c0]], %[[c0]]]
82+
// CHECK-DAG: %[[INS0:.*]] = tensor.insert %[[EXT0]] into %[[cst]][%[[c0]], %[[c7]]]
83+
// CHECK-DAG: %[[EXT1:.*]] = tensor.extract %arg0[%[[c0]], %[[c1]]]
84+
// CHECK-DAG: %[[INS1:.*]] = tensor.insert %[[EXT1]] into %[[INS0]][%[[c0]], %[[c6]]]
85+
// CHECK-DAG: %[[EXT2:.*]] = tensor.extract %arg0[%[[c0]], %[[c2]]]
86+
// CHECK-DAG: %[[INS2:.*]] = tensor.insert %[[EXT2]] into %[[INS1]][%[[c0]], %[[c5]]]
87+
// CHECK-DAG: %[[EXT3:.*]] = tensor.extract %arg0[%[[c0]], %[[c3]]]
88+
// CHECK-DAG: %[[INS3:.*]] = tensor.insert %[[EXT3]] into %[[INS2]][%[[c0]], %[[c4]]]
89+
// CHECK-DAG: %[[EXT4:.*]] = tensor.extract %arg0[%[[c0]], %[[c4]]]
90+
// CHECK-DAG: %[[INS4:.*]] = tensor.insert %[[EXT4]] into %[[INS3]][%[[c0]], %[[c3]]]
91+
// CHECK-DAG: %[[EXT5:.*]] = tensor.extract %arg0[%[[c0]], %[[c5]]]
92+
// CHECK-DAG: %[[INS5:.*]] = tensor.insert %[[EXT5]] into %[[INS4]][%[[c0]], %[[c2]]]
93+
// CHECK-DAG: %[[EXT6:.*]] = tensor.extract %arg0[%[[c0]], %[[c6]]]
94+
// CHECK-DAG: %[[INS6:.*]] = tensor.insert %[[EXT6]] into %[[INS5]][%[[c0]], %[[c1]]]
95+
// CHECK-DAG: %[[EXT7:.*]] = tensor.extract %arg0[%[[c0]], %[[c7]]]
96+
// CHECK-DAG: %[[INS7:.*]] = tensor.insert %[[EXT7]] into %[[INS6]][%[[c0]], %[[c0]]]
97+
98+
// CHECK: @permutate_vector
99+
#layout = dense<[[0, 0, 0, 7], [0, 1, 0, 6], [0, 2, 0, 5], [0, 3, 0, 4], [0, 4, 0, 3], [0, 5, 0, 2], [0, 6, 0, 1], [0, 7, 0, 0]]> : tensor<8x4xi64>
100+
module {
101+
func.func @permutate_vector() {
102+
%cst = arith.constant dense<1> : tensor<1x32xi16>
103+
// CHECK: %[[cst:.*]] = arith.constant dense<1> : tensor<1x32xi16>
104+
// CHECK: func.call @_assign_layout_{{[0-9]+}}(%[[cst]])
105+
%0 = secret.generic() {
106+
%1 = tensor_ext.assign_layout %cst {layout = #layout, tensor_ext.layout = #layout} : tensor<1x32xi16>
107+
secret.yield %1 : tensor<1x32xi16>
108+
} -> (!secret.secret<tensor<1x32xi16>> {tensor_ext.layout = #layout})
109+
return
110+
}
111+
}

0 commit comments

Comments
 (0)