Skip to content

Commit 77aab87

Browse files
j2kuncopybara-github
authored andcommitted
lattigo: fix _ := foo emitter bug
PiperOrigin-RevId: 918498275
1 parent 4056e08 commit 77aab87

2 files changed

Lines changed: 51 additions & 2 deletions

File tree

lib/Target/Lattigo/LattigoEmitter.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,12 @@ LogicalResult LattigoEmitter::printOperation(affine::AffineForOp op) {
328328
for (auto i = 0; i < op.getNumRegionIterArgs(); ++i) {
329329
Value opResult = op.getResult(i);
330330
Value iterArg = op.getRegionIterArgs()[i];
331-
os << getName(opResult) << " := " << getName(iterArg) << "\n";
331+
std::string resultName = getName(opResult);
332+
if (resultName == "_") {
333+
os << resultName << " = " << getName(iterArg) << "\n";
334+
} else {
335+
os << resultName << " := " << getName(iterArg) << "\n";
336+
}
332337
}
333338
return success();
334339
}
@@ -970,7 +975,12 @@ LogicalResult LattigoEmitter::printOperation(scf::ForOp op) {
970975
for (int i = 0; i < op.getNumResults(); ++i) {
971976
Value opResult = op.getResult(i);
972977
Value iterArg = op.getRegionIterArg(i);
973-
os << getName(opResult) << " := " << getName(iterArg) << "\n";
978+
std::string resultName = getName(opResult);
979+
if (resultName == "_") {
980+
os << resultName << " = " << getName(iterArg) << "\n";
981+
} else {
982+
os << resultName << " := " << getName(iterArg) << "\n";
983+
}
974984
}
975985
return success();
976986
}

tests/Emitter/Lattigo/emit_loops.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,43 @@ module attributes {scheme.bgv} {
7373
}
7474
return %1 : !ct
7575
}
76+
// CHECK: test_affine_for_unused_result
77+
// CHECK-SAME: [[ct_init:[^ ]*]] *rlwe.Ciphertext
78+
// CHECK: [[iter_arg:[^ ]*]] := [[ct_init]]
79+
// CHECK: for [[induction_var:[^ ]*]] := 1; [[induction_var]] < 10; [[induction_var]] += 2 {
80+
// CHECK: [[op_result:[^ ,]*]], [[err0:[^ ]*]] := evaluator.RotateColumnsNew([[iter_arg]], 1)
81+
// CHECK: [[iter_arg]] = [[op_result]]
82+
// CHECK: }
83+
// CHECK: _ = [[iter_arg]]
84+
// CHECK: return
85+
func.func @test_affine_for_unused_result(%evaluator: !evaluator, %ct: !ct) {
86+
%1 = affine.for %arg0 = 1 to 10 step 2 iter_args(%arg1 = %ct) -> (!ct) {
87+
%ct_12 = lattigo.bgv.rotate_columns_new %evaluator, %arg1 {static_shift = 1} : (!evaluator, !ct) -> !ct
88+
affine.yield %ct_12 : !ct
89+
}
90+
return
91+
}
92+
93+
// CHECK: test_scf_for_unused_result
94+
// CHECK-SAME: [[ct_init:[^ ]*]] *rlwe.Ciphertext
95+
// CHECK: [[c1:[^ ]*]] := int64(1)
96+
// CHECK: [[c10:[^ ]*]] := int64(10)
97+
// CHECK: [[c2:[^ ]*]] := int64(2)
98+
// CHECK: [[iter_arg:[^ ]*]] := [[ct_init]]
99+
// CHECK: for [[induction_var:[^ ]*]] := [[c1]]; [[induction_var]] < [[c10]]; [[induction_var]] += [[c2]] {
100+
// CHECK: [[op_result:[^ ,]*]], [[err0:[^ ]*]] := evaluator.RotateColumnsNew([[iter_arg]], 1)
101+
// CHECK: [[iter_arg]] = [[op_result]]
102+
// CHECK: }
103+
// CHECK: _ = [[iter_arg]]
104+
// CHECK: return
105+
func.func @test_scf_for_unused_result(%evaluator: !evaluator, %ct: !ct) {
106+
%c1 = arith.constant 1 : index
107+
%c10 = arith.constant 10 : index
108+
%c2 = arith.constant 2 : index
109+
%1 = scf.for %arg0 = %c1 to %c10 step %c2 iter_args(%arg1 = %ct) -> (!ct) {
110+
%ct_12 = lattigo.bgv.rotate_columns_new %evaluator, %arg1 {static_shift = 1} : (!evaluator, !ct) -> !ct
111+
scf.yield %ct_12 : !ct
112+
}
113+
return
114+
}
76115
}

0 commit comments

Comments
 (0)