Skip to content

Commit 777fd54

Browse files
gcarranza-1xnnpack-bot
authored andcommitted
Migrate LiteRT ATS unary op graph generation to use litert::tensor API.
PiperOrigin-RevId: 917866082
1 parent ce14e18 commit 777fd54

3 files changed

Lines changed: 52 additions & 0 deletions

File tree

litert/tensor/arithmetic.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,30 @@ Tensor<Mixins...> Relu6(Tensor<Mixins...> a,
165165
return output;
166166
}
167167

168+
template <class... Mixins>
169+
Tensor<Mixins...> ReluN1To1(Tensor<Mixins...> a,
170+
source_location loc = source_location::current()) {
171+
Tensor<Mixins...> output =
172+
ElementwiseOp<graph::ReluN1To1Operation<Mixins...>>(loc, a);
173+
return output;
174+
}
175+
176+
template <class... Mixins>
177+
Tensor<Mixins...> ZerosLike(Tensor<Mixins...> a,
178+
source_location loc = source_location::current()) {
179+
Tensor<Mixins...> output =
180+
ElementwiseOp<graph::ZerosLikeOperation<Mixins...>>(loc, a);
181+
return output;
182+
}
183+
184+
template <class... Mixins>
185+
Tensor<Mixins...> Relu0To1(Tensor<Mixins...> a,
186+
source_location loc = source_location::current()) {
187+
Tensor<Mixins...> output =
188+
ElementwiseOp<graph::Relu0To1Operation<Mixins...>>(loc, a);
189+
return output;
190+
}
191+
168192
template <class... Mixins>
169193
Tensor<Mixins...> LeakyRelu(Tensor<Mixins...> a, float alpha = 0.2f,
170194
source_location loc = source_location::current()) {

litert/tensor/arithmetic_graph.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,27 @@ struct Relu6Operation : virtual Operation,
8383
absl::string_view GetName() const override { return "Relu6"; }
8484
};
8585

86+
template <class... Mixins>
87+
struct ReluN1To1Operation
88+
: virtual Operation,
89+
virtual OpMixin<struct ReluN1To1OperationTag, Mixins>... {
90+
absl::string_view GetName() const override { return "ReluN1To1"; }
91+
};
92+
93+
template <class... Mixins>
94+
struct ZerosLikeOperation
95+
: virtual Operation,
96+
virtual OpMixin<struct ZerosLikeOperationTag, Mixins>... {
97+
absl::string_view GetName() const override { return "ZerosLike"; }
98+
};
99+
100+
template <class... Mixins>
101+
struct Relu0To1Operation
102+
: virtual Operation,
103+
virtual OpMixin<struct Relu0To1OperationTag, Mixins>... {
104+
absl::string_view GetName() const override { return "Relu0To1"; }
105+
};
106+
86107
struct LeakyReluOperationData {
87108
float alpha;
88109
};

litert/tensor/datatypes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ limitations under the License.
2222

2323
#include "litert/tensor/internal/fp16.h"
2424

25+
namespace tflite {
26+
struct half;
27+
}
28+
2529
namespace litert::tensor {
2630

2731
enum class Type {
@@ -266,6 +270,9 @@ struct ApiType<bf16_t> : internal::StorageImpl<Type::kBF16, bf16_t> {};
266270
template <>
267271
struct ApiType<fp16_t> : internal::StorageImpl<Type::kFP16, fp16_t> {};
268272

273+
template <>
274+
struct ApiType<tflite::half> : internal::StorageImpl<Type::kFP16, fp16_t> {};
275+
269276
template <>
270277
struct ApiType<float> : internal::StorageImpl<Type::kFP32, float> {};
271278

0 commit comments

Comments
 (0)