Skip to content

Commit 0e5fdd8

Browse files
asraacopybara-github
authored andcommitted
regression test for in-place issue #2635
PiperOrigin-RevId: 868309124
1 parent 789b880 commit 0e5fdd8

3 files changed

Lines changed: 199 additions & 0 deletions

File tree

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# See README.md for setup required to run these tests
2+
3+
load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib")
4+
load("@rules_go//go:def.bzl", "go_test")
5+
6+
package(default_applicable_licenses = ["@heir//:license"])
7+
8+
heir_lattigo_lib(
9+
name = "in_place",
10+
go_library_name = "main",
11+
heir_opt_flags = [
12+
"--scheme-to-lattigo",
13+
],
14+
mlir_src = "in_place.mlir",
15+
)
16+
17+
# For Google-internal reasons we must separate the go_test rules from the macro
18+
# above.
19+
20+
go_test(
21+
name = "in_place_test",
22+
srcs = ["in_place_test.go"],
23+
embed = [":main"],
24+
deps = [
25+
"@com_github_tuneinsight_lattigo_v6//core/rlwe",
26+
"@com_github_tuneinsight_lattigo_v6//schemes/ckks",
27+
],
28+
)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
!Z1056763241666817029_i64 = !mod_arith.int<1056763241666817029 : i64>
2+
!Z1106058412451299513_i64 = !mod_arith.int<1106058412451299513 : i64>
3+
!Z957769724367225479_i64 = !mod_arith.int<957769724367225479 : i64>
4+
#inverse_canonical_encoding = #lwe.inverse_canonical_encoding<scaling_factor = 60>
5+
#inverse_canonical_encoding1 = #lwe.inverse_canonical_encoding<scaling_factor = 40>
6+
#inverse_canonical_encoding2 = #lwe.inverse_canonical_encoding<scaling_factor = 100>
7+
#key = #lwe.key<>
8+
#modulus_chain_L1_C1 = #lwe.modulus_chain<elements = <1106058412451299513 : i64, 1056763241666817029 : i64>, current = 1>
9+
#modulus_chain_L2_C2 = #lwe.modulus_chain<elements = <1106058412451299513 : i64, 1056763241666817029 : i64, 957769724367225479 : i64>, current = 2>
10+
#ring_f64_1_x131072 = #polynomial.ring<coefficientType = f64, polynomialModulus = <1 + x**131072>>
11+
!rns_L1 = !rns.rns<!Z1106058412451299513_i64, !Z1056763241666817029_i64>
12+
!rns_L2 = !rns.rns<!Z1106058412451299513_i64, !Z1056763241666817029_i64, !Z957769724367225479_i64>
13+
!pt = !lwe.lwe_plaintext<application_data = <message_type = tensor<65536xf64>>, plaintext_space = <ring = #ring_f64_1_x131072, encoding = #inverse_canonical_encoding1>>
14+
#ring_rns_L1_1_x131072 = #polynomial.ring<coefficientType = !rns_L1, polynomialModulus = <1 + x**131072>>
15+
#ring_rns_L2_1_x131072 = #polynomial.ring<coefficientType = !rns_L2, polynomialModulus = <1 + x**131072>>
16+
#ciphertext_space_L1 = #lwe.ciphertext_space<ring = #ring_rns_L1_1_x131072, encryption_type = mix>
17+
#ciphertext_space_L2 = #lwe.ciphertext_space<ring = #ring_rns_L2_1_x131072, encryption_type = mix>
18+
!ct_L1 = !lwe.lwe_ciphertext<application_data = <message_type = tensor<65536xf64>>, plaintext_space = <ring = #ring_f64_1_x131072, encoding = #inverse_canonical_encoding1>, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1>
19+
!ct_L2 = !lwe.lwe_ciphertext<application_data = <message_type = tensor<65536xf64>>, plaintext_space = <ring = #ring_f64_1_x131072, encoding = #inverse_canonical_encoding>, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L2_C2>
20+
!ct_L2_1 = !lwe.lwe_ciphertext<application_data = <message_type = tensor<65536xf64>>, plaintext_space = <ring = #ring_f64_1_x131072, encoding = #inverse_canonical_encoding2>, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L2_C2>
21+
module attributes {ckks.schemeParam = #ckks.scheme_param<logN = 17, Q = [1106058412451299513, 1056763241666817029, 957769724367225479, 919081519653443687, 1030837924888066153, 1084354410096143723, 1135846243351935917, 1087115004561311021, 997960547764032911, 892538949448853293, 1002528331340998513, 1100798419621231379, 981696679688787961, 1061922508412786269], P = [1152921504606846976], logDefaultScale = 60>, scheme.ckks} {
22+
func.func @in_place(%ct: !ct_L2) -> !ct_L1 {
23+
%ct_0 = ckks.rotate %ct {offset = 0 : i32} : !ct_L2
24+
%cst = arith.constant dense<0.000000e+00> : tensor<65536xf64>
25+
%pt = lwe.rlwe_encode %cst {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt
26+
%ct_1 = ckks.mul_plain %ct_0, %pt : (!ct_L2, !pt) -> !ct_L2_1
27+
%ct_2 = ckks.rescale %ct_1 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1
28+
%ct_3 = ckks.rotate %ct {offset = 1 : i32} : !ct_L2
29+
%cst_4 = arith.constant dense<0.000000e+00> : tensor<65536xf64>
30+
%pt_5 = lwe.rlwe_encode %cst_4 {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt
31+
%ct_6 = ckks.mul_plain %ct_3, %pt_5 : (!ct_L2, !pt) -> !ct_L2_1
32+
%ct_7 = ckks.rescale %ct_6 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1
33+
%ct_8 = ckks.add %ct_2, %ct_7 : (!ct_L1, !ct_L1) -> !ct_L1
34+
%ct_9 = ckks.rotate %ct {offset = 2 : i32} : !ct_L2
35+
%cst_10 = arith.constant dense<0.000000e+00> : tensor<65536xf64>
36+
%pt_11 = lwe.rlwe_encode %cst_10 {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt
37+
%ct_12 = ckks.mul_plain %ct_9, %pt_11 : (!ct_L2, !pt) -> !ct_L2_1
38+
%ct_13 = ckks.rescale %ct_12 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1
39+
%ct_14 = ckks.add %ct_8, %ct_13 : (!ct_L1, !ct_L1) -> !ct_L1
40+
%ct_15 = ckks.rotate %ct {offset = 3 : i32} : !ct_L2
41+
%cst_16 = arith.constant dense<0.000000e+00> : tensor<65536xf64>
42+
%pt_17 = lwe.rlwe_encode %cst_16 {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt
43+
%ct_18 = ckks.mul_plain %ct_15, %pt_17 : (!ct_L2, !pt) -> !ct_L2_1
44+
%ct_19 = ckks.rescale %ct_18 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1
45+
%ct_20 = ckks.add %ct_14, %ct_19 : (!ct_L1, !ct_L1) -> !ct_L1
46+
%ct_21 = ckks.rotate %ct {offset = 4 : i32} : !ct_L2
47+
%cst_22 = arith.constant dense<0.000000e+00> : tensor<65536xf64>
48+
%pt_23 = lwe.rlwe_encode %cst_22 {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt
49+
%ct_24 = ckks.mul_plain %ct_21, %pt_23 : (!ct_L2, !pt) -> !ct_L2_1
50+
%ct_25 = ckks.rescale %ct_24 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1
51+
%ct_26 = ckks.add %ct_20, %ct_25 : (!ct_L1, !ct_L1) -> !ct_L1
52+
%ct_27 = ckks.rotate %ct {offset = 5 : i32} : !ct_L2
53+
%cst_28 = arith.constant dense<0.000000e+00> : tensor<65536xf64>
54+
%pt_29 = lwe.rlwe_encode %cst_28 {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt
55+
%ct_30 = ckks.mul_plain %ct_27, %pt_29 : (!ct_L2, !pt) -> !ct_L2_1
56+
%ct_31 = ckks.rescale %ct_30 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1
57+
%ct_32 = ckks.add %ct_26, %ct_31 : (!ct_L1, !ct_L1) -> !ct_L1
58+
%ct_33 = ckks.rotate %ct {offset = 6 : i32} : !ct_L2
59+
%cst_34 = arith.constant dense<0.000000e+00> : tensor<65536xf64>
60+
%pt_35 = lwe.rlwe_encode %cst_34 {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt
61+
%ct_36 = ckks.mul_plain %ct_33, %pt_35 : (!ct_L2, !pt) -> !ct_L2_1
62+
%ct_37 = ckks.rescale %ct_36 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1
63+
%ct_38 = ckks.add %ct_32, %ct_37 : (!ct_L1, !ct_L1) -> !ct_L1
64+
return %ct_38 : !ct_L1
65+
}
66+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
"time"
7+
8+
"github.com/tuneinsight/lattigo/v6/core/rlwe"
9+
"github.com/tuneinsight/lattigo/v6/schemes/ckks"
10+
)
11+
12+
// MakeFlattenedOnes creates a slice of float64 filled with 1.0s.
13+
// The size of the slice is determined by the product of the input 2D dimensions (rows * cols).
14+
func MakeFlattenedOnes(rows, cols int) []float64 {
15+
size := rows * cols
16+
tensor := make([]float64, size)
17+
for i := range tensor {
18+
tensor[i] = 1.0
19+
}
20+
return tensor
21+
}
22+
23+
func makeRange(n int) []int {
24+
a := make([]int, n)
25+
for i := range a {
26+
a[i] = i
27+
}
28+
return a
29+
}
30+
31+
func generateGalEls(param ckks.Parameters, indices []int) []uint64 {
32+
var galEls []uint64
33+
for _, index := range indices {
34+
galEls = append(galEls, param.GaloisElement(index))
35+
}
36+
return galEls
37+
}
38+
39+
func TestMLP(t *testing.T) {
40+
logN := 14
41+
numSlots := 1 << (logN - 1)
42+
43+
// Input is arbitrary, doesn't matter since we're just testing
44+
// performance
45+
inputClear := make([]float64, numSlots)
46+
for i := range inputClear {
47+
inputClear[i] = 1.0
48+
}
49+
50+
// Function args:
51+
//
52+
// %ct: encrypted input,
53+
54+
// These parameters should match the mlir file, though due to the weird
55+
// nature of this test, this is the source of truth for what is used,
56+
// not the mlir file.
57+
logQ := make([]int, 7)
58+
for i := range logQ {
59+
logQ[i] = 60
60+
}
61+
param, err := ckks.NewParametersFromLiteral(ckks.ParametersLiteral{
62+
LogN: logN,
63+
LogQ: logQ,
64+
LogP: []int{60},
65+
LogDefaultScale: 40,
66+
})
67+
if err != nil {
68+
panic(err)
69+
}
70+
71+
encoder := ckks.NewEncoder(param)
72+
kgen := rlwe.NewKeyGenerator(param)
73+
sk, pk := kgen.GenKeyPairNew()
74+
encryptor := rlwe.NewEncryptor(param, pk)
75+
rk := kgen.GenRelinearizationKeyNew(sk)
76+
77+
// We have to do this once for each distinct linear_transform op to
78+
// ensure we generate all the galois keys needed by lattigo
79+
var galEls []uint64
80+
// Manually add Galois key for extra rotation indices used in the
81+
// mlir file, outside of linear_transform
82+
//
83+
// For some reason I need to manually add rotation keys used in
84+
// linear_transform! That should have been handled by the above code...
85+
rotIndices := makeRange(10)
86+
galEls = append(galEls, generateGalEls(param, rotIndices)...)
87+
88+
fmt.Printf("Final galEls: %v\n", galEls)
89+
90+
evk := rlwe.NewMemEvaluationKeySet(rk, kgen.GenGaloisKeysNew(galEls, sk)...)
91+
evaluator := ckks.NewEvaluator(param, evk)
92+
93+
pt := ckks.NewPlaintext(param, 2)
94+
encoder.Encode(inputClear, pt)
95+
ctInput, err25 := encryptor.EncryptNew(pt)
96+
if err25 != nil {
97+
panic(err25)
98+
}
99+
100+
fmt.Printf("Starting call")
101+
startTime := time.Now()
102+
in_place(evaluator, param, encoder, ctInput)
103+
duration := time.Since(startTime)
104+
fmt.Printf("MLP call took: %v\n", duration)
105+
}

0 commit comments

Comments
 (0)