Skip to content

Commit 0dbe09a

Browse files
Aelphy authored and copybara-github committed
Eradicate Ruy from TFLite CPU Backend for GEMM.
PiperOrigin-RevId: 893726969
1 parent e136113 commit 0dbe09a

3 files changed

Lines changed: 145 additions & 46 deletions

File tree

tflite/kernels/BUILD

Lines changed: 18 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -367,7 +367,6 @@ cc_library(
367367
copts = tflite_copts(),
368368
deps = [
369369
":cpu_backend_context",
370-
":tflite_with_ruy",
371370
"//tflite/kernels/internal:compatibility",
372371
# For now this unconditionally depends on both ruy and gemmlowp.
373372
# We only need to depend on gemmlowp when tflite_with_ruy
@@ -401,11 +400,13 @@ cc_library(
401400
hdrs = [
402401
"cpu_backend_gemm.h",
403402
"cpu_backend_gemm_params.h",
403+
"cpu_backend_gemm_reference.h",
404404
"cpu_backend_gemm_ruy.h",
405405
],
406406
compatible_with = get_compatible_with_portable(),
407407
copts = tflite_copts(),
408408
deps = [
409+
"//tflite:minimal_logging",
409410
":tflite_with_ruy",
410411
"//tflite/kernels/internal:common",
411412
"//tflite/kernels/internal:compatibility",
@@ -488,7 +489,6 @@ cc_test(
488489
tags = ["tflite_smoke_test"],
489490
deps = [
490491
":rng_util",
491-
":test_util",
492492
"@com_google_googletest//:gtest_main",
493493
],
494494
)
@@ -863,7 +863,6 @@ cc_library(
863863
"//tflite/core/c:common",
864864
"//tflite/experimental/resource",
865865
"//tflite/kernels/internal:tensor",
866-
"@flatbuffers",
867866
],
868867
)
869868

@@ -899,9 +898,7 @@ pybind_extension(
899898
],
900899
deps = [
901900
"//tflite:framework_stable",
902-
"//tflite:mutable_op_resolver",
903901
"@pybind11",
904-
"@xla//third_party/python_runtime:headers",
905902
],
906903
)
907904

@@ -913,10 +910,9 @@ cc_library(
913910
copts = tflite_copts(),
914911
visibility = ["//visibility:private"],
915912
deps = [
916-
"//tflite/kernels:cpu_backend_context",
913+
":cpu_backend_context",
917914
"//tflite/kernels/internal:optimized_base",
918915
"//tflite/kernels/internal:tensor",
919-
"@eigen_archive//:eigen3",
920916
],
921917
)
922918

@@ -938,11 +934,11 @@ cc_library(
938934
compatible_with = get_compatible_with_portable(),
939935
copts = tflite_copts(),
940936
deps = [
937+
":cpu_backend_context",
941938
":gru_cell",
942939
":kernel_util",
940+
":padding",
943941
"//tflite/core/c:common",
944-
"//tflite/kernels:cpu_backend_context",
945-
"//tflite/kernels:padding",
946942
"//tflite/kernels/internal:common",
947943
"//tflite/kernels/internal:compatibility",
948944
"//tflite/kernels/internal:reference_base",
@@ -1052,7 +1048,6 @@ cc_library(
10521048
compatible_with = get_compatible_with_portable(),
10531049
copts = tflite_copts(),
10541050
deps = [
1055-
":builtin_op_kernels",
10561051
"//tflite:framework_stable",
10571052
"//tflite:util",
10581053
"//tflite/core/c:common",
@@ -1470,7 +1465,6 @@ cc_test(
14701465
":test_main",
14711466
":test_util",
14721467
"//tflite/schema:schema_fbs",
1473-
"//tflite/testing:util",
14741468
"@com_google_googletest//:gtest",
14751469
],
14761470
)
@@ -1507,7 +1501,6 @@ cc_test(
15071501
"@com_google_absl//absl/random",
15081502
"@com_google_absl//absl/types:span",
15091503
"@com_google_googletest//:gtest",
1510-
"@eigen_archive//:eigen3",
15111504
],
15121505
)
15131506

@@ -1574,9 +1567,7 @@ cc_test(
15741567
":test_main",
15751568
":test_util",
15761569
"//tflite:framework_stable",
1577-
"//tflite:string",
15781570
"//tflite/schema:schema_fbs",
1579-
"@com_google_absl//absl/memory",
15801571
"@com_google_googletest//:gtest",
15811572
],
15821573
)
@@ -1646,13 +1637,9 @@ cc_test(
16461637
":test_util",
16471638
"//tflite:framework_stable",
16481639
"//tflite/core:framework_stable",
1649-
"//tflite/core/api",
1650-
"//tflite/kernels/internal:types",
16511640
"//tflite/schema:schema_fbs",
1652-
"@com_google_absl//absl/memory",
16531641
"@com_google_googletest//:gtest",
16541642
"@eigen_archive//:eigen3",
1655-
"@flatbuffers",
16561643
],
16571644
)
16581645

@@ -1874,7 +1861,6 @@ cc_library(
18741861
],
18751862
deps = [
18761863
":test_util",
1877-
"//tflite:string",
18781864
"//tflite/schema:schema_fbs",
18791865
],
18801866
)
@@ -1963,8 +1949,6 @@ cc_test(
19631949
":test_main",
19641950
":test_util",
19651951
"//tflite/kernels/internal:types",
1966-
"//tflite/schema:schema_fbs",
1967-
"//tflite/testing:util",
19681952
"@com_google_googletest//:gtest",
19691953
],
19701954
)
@@ -2051,7 +2035,6 @@ cc_test(
20512035
"//tflite/core/c:common",
20522036
"//tflite/schema:schema_fbs",
20532037
"@com_google_googletest//:gtest",
2054-
"@flatbuffers",
20552038
],
20562039
)
20572040

@@ -2292,7 +2275,6 @@ cc_test(
22922275
":custom_ops",
22932276
":test_main",
22942277
":test_util",
2295-
"//tflite/schema:schema_fbs",
22962278
"@com_google_googletest//:gtest",
22972279
"@flatbuffers",
22982280
],
@@ -2504,7 +2486,6 @@ cc_test(
25042486
size = "small",
25052487
srcs = ["table_test.cc"],
25062488
deps = [
2507-
":custom_ops",
25082489
":test_main",
25092490
":test_util",
25102491
"//tflite/kernels/internal:common",
@@ -2843,7 +2824,6 @@ cc_test(
28432824
":kernel_util",
28442825
":subgraph_test_util",
28452826
":test_main",
2846-
":variable_op_kernels",
28472827
"//tflite:framework_stable",
28482828
"//tflite/core:framework_stable",
28492829
"@com_google_googletest//:gtest",
@@ -2863,7 +2843,6 @@ cc_test(
28632843
"//tflite/core:framework_stable",
28642844
"//tflite/delegates/xnnpack:xnnpack_delegate",
28652845
"//tflite/kernels/internal:tensor",
2866-
"//tflite/profiling:memory_info",
28672846
"@com_google_googletest//:gtest",
28682847
],
28692848
)
@@ -3141,18 +3120,15 @@ cc_test(
31413120
"hashtable_ops_test.cc",
31423121
],
31433122
deps = [
3123+
":test_main",
3124+
":test_util",
31443125
"//tflite:framework_stable",
31453126
"//tflite/core:framework_stable",
3146-
"//tflite/core/api",
31473127
"//tflite/experimental/resource",
3148-
"//tflite/kernels:test_main",
3149-
"//tflite/kernels:test_util",
31503128
"//tflite/kernels/internal:tensor",
31513129
"//tflite/testing:util",
31523130
"@com_google_absl//absl/memory",
31533131
"@com_google_absl//absl/strings",
3154-
"@com_google_googletest//:gtest",
3155-
"@flatbuffers",
31563132
],
31573133
)
31583134

@@ -3162,9 +3138,8 @@ cc_test(
31623138
srcs = ["unidirectional_sequence_gru_test.cc"],
31633139
tags = ["tflite_not_portable_ios"],
31643140
deps = [
3165-
":custom_ops",
3141+
":test_util",
31663142
"//tflite:framework_stable",
3167-
"//tflite/kernels:test_util",
31683143
"@com_google_googletest//:gtest_main",
31693144
],
31703145
)
@@ -3177,9 +3152,7 @@ cc_test(
31773152
":test_main",
31783153
":test_util",
31793154
"//tflite:framework_stable",
3180-
"//tflite:string",
31813155
"//tflite/schema:schema_fbs",
3182-
"@com_google_absl//absl/memory",
31833156
"@com_google_googletest//:gtest",
31843157
],
31853158
)
@@ -3221,7 +3194,6 @@ cc_test(
32213194
size = "small",
32223195
srcs = ["stablehlo_add_test.cc"],
32233196
deps = [
3224-
":subgraph_test_util",
32253197
":test_util",
32263198
"//tflite/c:c_api_types",
32273199
"//tflite/c:common",
@@ -3237,7 +3209,6 @@ cc_test(
32373209
size = "small",
32383210
srcs = ["stablehlo_multiply_test.cc"],
32393211
deps = [
3240-
":subgraph_test_util",
32413212
":test_util",
32423213
"//tflite/c:c_api_types",
32433214
"//tflite/c:common",
@@ -3438,6 +3409,16 @@ cc_test(
34383409
],
34393410
)
34403411

3412+
cc_test(
3413+
name = "unsorted_segment_min_test_cc",
3414+
srcs = ["unsorted_segment_min_test.cc"],
3415+
deps = [
3416+
":test_util",
3417+
"//tflite/schema:schema_fbs",
3418+
"@com_google_googletest//:gtest_main",
3419+
],
3420+
)
3421+
34413422
tflite_portable_test_suite_combined(
34423423
combine_conditions = {"deps": [":test_main"]},
34433424
# TODO(b/229985981) : Remove `nnapi_args` after adding Relu0To1 is completed.

tflite/kernels/cpu_backend_gemm.h

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ limitations under the License.
2222
#include "tflite/kernels/cpu_backend_context.h"
2323
#include "tflite/kernels/cpu_backend_gemm_custom_gemv.h"
2424
#include "tflite/kernels/cpu_backend_gemm_params.h"
25-
#include "tflite/kernels/cpu_backend_gemm_ruy.h"
25+
#include "tflite/kernels/cpu_backend_gemm_reference.h"
2626

2727
#ifndef TFLITE_WITH_RUY
2828
#include "tflite/kernels/cpu_backend_gemm_eigen.h"
@@ -66,8 +66,9 @@ struct GemmImpl : detail::GemmImplX86<LhsScalar, RhsScalar, AccumScalar,
6666
*/
6767
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
6868
typename DstScalar, QuantizationFlavor quantization_flavor>
69-
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
70-
DstScalar, quantization_flavor> {};
69+
struct GemmImpl
70+
: detail::GemmImplUsingReference<LhsScalar, RhsScalar, AccumScalar,
71+
DstScalar, quantization_flavor> {};
7172

7273
#if !defined(TFLITE_WITH_RUY)
7374

@@ -86,20 +87,20 @@ struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
8687
template <typename SrcScalar, QuantizationFlavor quantization_flavor>
8788
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
8889
quantization_flavor>
89-
: detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
90-
quantization_flavor> {};
90+
: detail::GemmImplUsingReference<SrcScalar, SrcScalar, std::int32_t,
91+
std::int8_t, quantization_flavor> {};
9192

9293
template <typename DstScalar, QuantizationFlavor quantization_flavor>
9394
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
9495
quantization_flavor>
95-
: detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
96-
DstScalar, quantization_flavor> {};
96+
: detail::GemmImplUsingReference<std::int8_t, std::int8_t, std::int32_t,
97+
DstScalar, quantization_flavor> {};
9798

9899
template <QuantizationFlavor quantization_flavor>
99100
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
100101
quantization_flavor>
101-
: detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
102-
std::int8_t, quantization_flavor> {};
102+
: detail::GemmImplUsingReference<std::int8_t, std::int8_t, std::int32_t,
103+
std::int8_t, quantization_flavor> {};
103104
#endif // not GEMMLOWP_NEON
104105

105106
/* Specializations using Eigen */

0 commit comments

Comments
 (0)