@@ -195,6 +195,51 @@ static FailureOr<Value> implementUnpackOpNew(
195195 return loop.value ().getResults ()[0 ];
196196}
197197
198+ static FailureOr<Value> implementAssignLayoutPermutation (
199+ Value input, DenseIntElementsAttr permutation, int64_t ciphertextSize,
200+ ImplicitLocOpBuilder& builder,
201+ const std::function<void (Operation*)>& createdOpCallback) {
202+ auto elementType = getElementTypeOrSelf (input.getType ());
203+
204+ // TODO(#2666): This logic assumes ctxt = 0, this can be extended to handle a
205+ // more general case. permutation is <N x 4 x i64>: each row is [src_ct,
206+ // src_slot, dst_ct, dst_slot]. src_ct and dst_ct are always 0 (single
207+ // ciphertext), so treat as 1D: extract from input[0][src_slot] and insert
208+ // into result[0][dst_slot].
209+ auto ciphertextType = RankedTensorType::get ({1 , ciphertextSize}, elementType);
210+ auto zeroCtxt = arith::ConstantOp::create (
211+ builder, ciphertextType, builder.getZeroAttr (ciphertextType));
212+ createdOpCallback (zeroCtxt);
213+ Value result = zeroCtxt.getResult ();
214+ auto ctIdxOp = arith::ConstantIndexOp::create (builder, 0 );
215+ createdOpCallback (ctIdxOp);
216+ Value ctIdx = ctIdxOp.getResult ();
217+
218+ for (auto it = permutation.value_begin <APInt>();
219+ it != permutation.value_end <APInt>();) {
220+ // skip src_ct (always 0)
221+ ++it;
222+ int64_t srcSlot = (*it++).getSExtValue ();
223+ // skip dst_ct (always 0)
224+ ++it;
225+ int64_t dstSlot = (*it++).getSExtValue ();
226+
227+ auto srcSlotIdx = arith::ConstantIndexOp::create (builder, srcSlot);
228+ createdOpCallback (srcSlotIdx);
229+ auto extracted = tensor::ExtractOp::create (builder, input,
230+ ValueRange{ctIdx, srcSlotIdx});
231+ createdOpCallback (extracted);
232+ auto dstSlotIdx = arith::ConstantIndexOp::create (builder, dstSlot);
233+ createdOpCallback (dstSlotIdx);
234+ auto insertOp = tensor::InsertOp::create (
235+ builder, extracted.getResult (), result, ValueRange{ctIdx, dstSlotIdx});
236+ createdOpCallback (insertOp);
237+ result = insertOp.getResult ();
238+ }
239+
240+ return result;
241+ }
242+
198243FailureOr<Value> implementAssignLayout (
199244 Value input, Attribute layout, int64_t ciphertextSize,
200245 ImplicitLocOpBuilder& builder,
@@ -203,6 +248,10 @@ FailureOr<Value> implementAssignLayout(
203248 if (LayoutAttr layoutAttr = dyn_cast<LayoutAttr>(layout)) {
204249 return implementAssignLayoutNew (input, layoutAttr, ciphertextSize, builder,
205250 createdOpCallback);
251+ } else if (DenseIntElementsAttr elementAttr =
252+ dyn_cast<DenseIntElementsAttr>(layout)) {
253+ return implementAssignLayoutPermutation (input, elementAttr, ciphertextSize,
254+ builder, createdOpCallback);
206255 }
207256 return builder.emitError () << " Unsupported layout attribute type: " << layout;
208257};
0 commit comments