Skip to content

Commit 26f3db4

Browse files
terryheocopybara-github
authored andcommitted
Update CompiledModel::CreateInputBuffers() and CompiledModel::CreateOutputBuffers()
Fix them to respect alignment in the TensorBufferRequirements by using LiteRtCreateManagedTensorBufferFromRequirements() C API. This PR fixes #5373. LiteRT-PiperOrigin-RevId: 893695151
1 parent 2408dda commit 26f3db4

9 files changed

Lines changed: 136 additions & 39 deletions

litert/cc/BUILD

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ cc_library(
143143
hdrs = ["litert_expected.h"],
144144
compatible_with = get_compatible_with_portable(),
145145
deps = [
146+
":litert_common",
146147
"//litert/c:litert_common",
147-
"//litert/cc:litert_common",
148148
"@com_google_absl//absl/log:absl_check",
149149
"@com_google_absl//absl/strings",
150150
"@com_google_absl//absl/strings:str_format",
@@ -210,11 +210,11 @@ cc_library(
210210
srcs = ["litert_opaque_options.cc"],
211211
hdrs = ["litert_opaque_options.h"],
212212
deps = [
213+
":litert_common",
213214
":litert_expected",
214215
":litert_macros",
215216
"//litert/c:litert_common",
216217
"//litert/c:litert_opaque_options",
217-
"//litert/cc:litert_common",
218218
"//litert/cc/internal:litert_handle",
219219
"@com_google_absl//absl/strings:string_view",
220220
],
@@ -482,6 +482,15 @@ litert_device_test(
482482
# "requires-gpu-nvidia",
483483
# ],
484484
# deps = [
485+
# ":litert_common",
486+
# ":litert_compiled_model",
487+
# ":litert_environment",
488+
# ":litert_event",
489+
# ":litert_expected",
490+
# ":litert_macros",
491+
# ":litert_model",
492+
# ":litert_options",
493+
# ":litert_tensor_buffer",
485494
# "@com_google_googletest//:gtest_main",
486495
# "@com_google_absl//absl/debugging:leak_check",
487496
# "@com_google_absl//absl/log:absl_log",
@@ -493,15 +502,6 @@ litert_device_test(
493502
# "//litert/c:litert_event_type",
494503
# "//litert/c:litert_profiler_event",
495504
# "//litert/c:litert_tensor_buffer_types",
496-
# "//litert/cc:litert_common",
497-
# "//litert/cc:litert_compiled_model",
498-
# "//litert/cc:litert_environment",
499-
# "//litert/cc:litert_event",
500-
# "//litert/cc:litert_expected",
501-
# "//litert/cc:litert_macros",
502-
# "//litert/cc:litert_model",
503-
# "//litert/cc:litert_options",
504-
# "//litert/cc:litert_tensor_buffer",
505505
# "//litert/cc/internal:litert_platform_support",
506506
# "//litert/cc/options:litert_gpu_options",
507507
# "//litert/cc/options:litert_runtime_options",
@@ -629,14 +629,14 @@ cc_library(
629629
# copybara:comment_end
630630
],
631631
deps = [
632+
":litert_common",
632633
":litert_expected",
633634
":litert_layout",
634635
"//litert/c:litert_common",
635636
"//litert/c:litert_custom_op_kernel",
636637
"//litert/c:litert_layout",
637638
"//litert/c:litert_tensor_buffer",
638639
"//litert/c/internal:litert_logging",
639-
"//litert/cc:litert_common",
640640
"//litert/cc/internal:litert_handle",
641641
"//litert/cc/internal:litert_tensor_buffer_without_registry",
642642
"@com_google_absl//absl/types:span",
@@ -780,14 +780,14 @@ cc_library(
780780
# copybara:comment_end
781781
],
782782
deps = [
783+
":litert_environment",
783784
":litert_expected",
784785
":litert_macros",
785786
"//litert/c:litert_common",
786787
"//litert/c:litert_event_type",
787788
"//litert/c:litert_gl_types",
788789
"//litert/c:litert_opencl_types",
789790
"//litert/c:litert_profiler_event",
790-
"//litert/cc:litert_environment",
791791
"//litert/cc/internal:litert_handle",
792792
],
793793
)
@@ -797,9 +797,9 @@ cc_test(
797797
srcs = ["litert_event_test.cc"],
798798
linkopts = litert_android_linkopts() + gles_linkopts(),
799799
deps = [
800+
":litert_environment",
800801
":litert_event",
801802
"//litert/c:litert_event_type",
802-
"//litert/cc:litert_environment",
803803
"//litert/cc/internal:litert_platform_support",
804804
"//litert/test:matchers",
805805
"@com_google_googletest//:gtest_main",
@@ -951,7 +951,6 @@ cc_library(
951951
":litert_common",
952952
":litert_expected",
953953
":litert_tensor_buffer_types",
954-
"//litert/c:litert_common",
955954
"//litert/c/internal:litert_logging",
956955
"@com_google_absl//absl/types:span",
957956
],
@@ -984,11 +983,13 @@ cc_library(
984983
# copybara:comment_end
985984
],
986985
deps = [
986+
":litert_common",
987987
":litert_environment",
988988
":litert_event",
989989
":litert_expected",
990990
":litert_macros",
991991
":litert_ranked_tensor_type",
992+
":litert_tensor_buffer_requirements",
992993
":litert_tensor_buffer_types",
993994
"//litert/c:litert_common",
994995
"//litert/c:litert_custom_tensor_buffer",
@@ -1000,7 +1001,6 @@ cc_library(
10001001
"//litert/c:litert_tensor_buffer_types",
10011002
"//litert/c:litert_webgpu_types",
10021003
"//litert/c/internal:litert_tensor_buffer_registry", # buildcleaner: keep
1003-
"//litert/cc:litert_common",
10041004
"//litert/cc/internal:litert_handle",
10051005
"//litert/cc/internal:litert_tensor_buffer_without_registry",
10061006
"@com_google_absl//absl/cleanup",
@@ -1042,6 +1042,7 @@ cc_test(
10421042
# ":litert_model",
10431043
# ":litert_ranked_tensor_type",
10441044
# ":litert_tensor_buffer",
1045+
# ":litert_tensor_buffer_requirements",
10451046
# ":litert_tensor_buffer_types",
10461047
# "@com_google_googletest//:gtest_main",
10471048
# "@com_google_absl//absl/debugging:leak_check",

litert/cc/dynamic_runtime/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ cc_library(
469469
"//litert/cc:litert_expected",
470470
"//litert/cc:litert_macros",
471471
"//litert/cc:litert_ranked_tensor_type",
472+
"//litert/cc:litert_tensor_buffer_requirements",
472473
"//litert/cc:litert_tensor_buffer_types",
473474
"//litert/cc/internal:litert_handle",
474475
"@com_google_absl//absl/cleanup",
@@ -507,6 +508,7 @@ cc_test(
507508
"//litert/cc:litert_layout",
508509
"//litert/cc:litert_macros",
509510
"//litert/cc:litert_ranked_tensor_type",
511+
"//litert/cc:litert_tensor_buffer_requirements",
510512
"//litert/cc:litert_tensor_buffer_types",
511513
"//litert/cc/internal:litert_handle",
512514
"//litert/cc/internal:litert_platform_support",

litert/cc/internal/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ cc_library(
9595
"//litert/cc:litert_expected",
9696
"//litert/cc:litert_macros",
9797
"//litert/cc:litert_ranked_tensor_type",
98+
"//litert/cc:litert_tensor_buffer_requirements",
9899
"//litert/cc:litert_tensor_buffer_types",
99100
"@com_google_absl//absl/cleanup",
100101
"@com_google_absl//absl/log:absl_check",

litert/cc/litert_compiled_model.cc

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ Expected<TensorBufferRequirements> ToTensorBufferRequirements(
7373
LITERT_RETURN_IF_ERROR(env.runtime->GetTensorBufferRequirementsBufferSize(
7474
litert_requirements, &buffer_size));
7575

76+
size_t alignment;
77+
LITERT_RETURN_IF_ERROR(env.runtime->GetTensorBufferRequirementsAlignment(
78+
litert_requirements, &alignment));
79+
7680
int num_strides;
7781
const uint32_t* strides_ptr;
7882
LITERT_RETURN_IF_ERROR(env.runtime->GetTensorBufferRequirementsStrides(
@@ -82,10 +86,11 @@ Expected<TensorBufferRequirements> ToTensorBufferRequirements(
8286
strides.assign(strides_ptr, strides_ptr + num_strides);
8387
}
8488

85-
size_t alignment;
86-
LITERT_RETURN_IF_ERROR(env.runtime->GetTensorBufferRequirementsAlignment(
87-
litert_requirements, &alignment));
88-
89+
if (num_strides == 0 || strides[0] == 0) {
90+
// Strides are not specified.
91+
return TensorBufferRequirements::CreateWithAlignment(
92+
absl::MakeConstSpan(supported_types), buffer_size, alignment);
93+
}
8994
return TensorBufferRequirements::CreateWithAlignment(
9095
absl::MakeConstSpan(supported_types), buffer_size, alignment,
9196
absl::MakeConstSpan(strides));
@@ -347,20 +352,8 @@ Expected<size_t> CompiledModel::FindOutputIndex(
347352
Expected<TensorBuffer> CompiledModel::CreateBufferImpl(
348353
const Environment& env, const TensorBufferRequirements& buffer_requirements,
349354
const RankedTensorType& tensor_type) {
350-
LITERT_ASSIGN_OR_RETURN(const std::vector<TensorBufferType>& supported_types,
351-
buffer_requirements.SupportedTypes());
352-
if (supported_types.empty()) {
353-
return Unexpected(Status::kErrorRuntimeFailure,
354-
"Input doesn't support any tensor buffer types");
355-
}
356-
// For simplicity we just pick the first supported tensor buffer type.
357-
TensorBufferType tensor_buffer_type = supported_types[0];
358-
LITERT_ASSIGN_OR_RETURN(size_t buffer_size, buffer_requirements.BufferSize());
359-
360-
LITERT_ASSIGN_OR_RETURN(TensorBuffer buffer, TensorBuffer::CreateManaged(
361-
env, tensor_buffer_type,
362-
tensor_type, buffer_size));
363-
return buffer;
355+
return TensorBuffer::CreateManagedFromRequirements(env, tensor_type,
356+
buffer_requirements);
364357
}
365358

366359
Expected<TensorBuffer> CompiledModel::CreateInputOutputBuffer(

litert/cc/litert_compiled_model_test.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,8 +1153,11 @@ TEST(CompiledModelTest, GetBufferRequirementsDetailed) {
11531153
EXPECT_GE(input_alignment, 0);
11541154

11551155
LITERT_ASSERT_OK_AND_ASSIGN(auto input_strides, input_requirements.Strides());
1156-
EXPECT_EQ(input_strides.size(), 1);
1157-
EXPECT_EQ(input_strides[0], 0);
1156+
1157+
EXPECT_LE(input_strides.size(), 1);
1158+
if (input_strides.size() == 1) {
1159+
EXPECT_EQ(input_strides[0], 0);
1160+
}
11581161

11591162
// Check output buffer requirements.
11601163
LITERT_ASSERT_OK_AND_ASSIGN(
@@ -1175,8 +1178,10 @@ TEST(CompiledModelTest, GetBufferRequirementsDetailed) {
11751178

11761179
LITERT_ASSERT_OK_AND_ASSIGN(auto output_strides,
11771180
output_requirements.Strides());
1178-
EXPECT_EQ(output_strides.size(), 1);
1179-
EXPECT_EQ(output_strides[0], 0);
1181+
EXPECT_LE(output_strides.size(), 1);
1182+
if (output_strides.size() == 1) {
1183+
EXPECT_EQ(output_strides[0], 0);
1184+
}
11801185
}
11811186

11821187
} // namespace

litert/cc/litert_tensor_buffer.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
#include "litert/cc/litert_tensor_buffer.h"
1616

1717
#include <cstddef>
18+
#include <vector>
1819

20+
#include "absl/cleanup/cleanup.h" // from @com_google_absl
1921
#include "litert/c/litert_common.h"
2022
#include "litert/c/litert_gl_types.h"
2123
#include "litert/c/litert_model_types.h"
@@ -28,10 +30,44 @@
2830
#include "litert/cc/litert_expected.h"
2931
#include "litert/cc/litert_macros.h"
3032
#include "litert/cc/litert_ranked_tensor_type.h"
33+
#include "litert/cc/litert_tensor_buffer_requirements.h"
3134
#include "litert/cc/litert_tensor_buffer_types.h"
3235

3336
namespace litert {
3437

38+
namespace {
39+
40+
// Converts a `TensorBufferRequirements` C++ object to a
41+
// `LiteRtTensorBufferRequirements` C object. In compiled_model.cc, there is a
42+
// function named `ToTensorBufferRequirements` which converts a
43+
// `LiteRtTensorBufferRequirements` C object to a `TensorBufferRequirements`
44+
// C++ object.
45+
Expected<LiteRtTensorBufferRequirements> ToLiteRtTensorBufferRequirements(
46+
const internal::EnvironmentHolder& env,
47+
const TensorBufferRequirements& requirements) {
48+
LITERT_ASSIGN_OR_RETURN(const auto supported_types,
49+
requirements.SupportedTypes());
50+
LITERT_ASSIGN_OR_RETURN(const auto buffer_size, requirements.BufferSize());
51+
LITERT_ASSIGN_OR_RETURN(const auto alignment, requirements.Alignment());
52+
LITERT_ASSIGN_OR_RETURN(const auto strides, requirements.Strides());
53+
54+
std::vector<LiteRtTensorBufferType> litert_buffer_types;
55+
litert_buffer_types.reserve(supported_types.size());
56+
for (auto type : supported_types) {
57+
litert_buffer_types.push_back(static_cast<LiteRtTensorBufferType>(type));
58+
}
59+
60+
LiteRtTensorBufferRequirements litert_requirements;
61+
LITERT_RETURN_IF_ERROR(
62+
env.runtime->CreateTensorBufferRequirementsWithAlignment(
63+
litert_buffer_types.size(), litert_buffer_types.data(), buffer_size,
64+
strides.size(), strides.data(), alignment, &litert_requirements));
65+
66+
return litert_requirements;
67+
}
68+
69+
} // namespace
70+
3571
Expected<TensorBuffer> TensorBuffer::Duplicate() const {
3672
if (!IsOwned()) {
3773
return Unexpected(Status::kErrorInvalidArgument,
@@ -53,6 +89,28 @@ Expected<TensorBuffer> TensorBuffer::CreateManaged(
5389
return TensorBuffer(env_holder, tensor_buffer, OwnHandle::kYes);
5490
}
5591

92+
Expected<TensorBuffer> TensorBuffer::CreateManagedFromRequirements(
93+
const Environment& env, const RankedTensorType& tensor_type,
94+
const TensorBufferRequirements& requirements) {
95+
LiteRtTensorBuffer tensor_buffer;
96+
auto litert_tensor_type = static_cast<LiteRtRankedTensorType>(tensor_type);
97+
auto env_holder = env.GetHolder();
98+
99+
LITERT_ASSIGN_OR_RETURN(
100+
LiteRtTensorBufferRequirements litert_requirements,
101+
ToLiteRtTensorBufferRequirements(env_holder, requirements));
102+
103+
auto cleanup = absl::MakeCleanup([&env_holder, litert_requirements] {
104+
env_holder.runtime->DestroyTensorBufferRequirements(litert_requirements);
105+
});
106+
107+
LITERT_RETURN_IF_ERROR(
108+
env_holder.runtime->CreateManagedTensorBufferFromRequirements(
109+
env_holder.handle, &litert_tensor_type, litert_requirements,
110+
&tensor_buffer));
111+
return TensorBuffer(env_holder, tensor_buffer, OwnHandle::kYes);
112+
}
113+
56114
// TODO(terryheo): make this function not depend on Environment.
57115
Expected<TensorBuffer> TensorBuffer::CreateManagedHostMemory(
58116
const RankedTensorType& tensor_type, size_t buffer_size) {

litert/cc/litert_tensor_buffer.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444

4545
namespace litert {
4646

47+
class TensorBufferRequirements;
48+
4749
/// @brief A C++ wrapper for `LiteRtTensorBuffer`, representing a tensor and
4850
/// its associated backing buffer.
4951
class TensorBuffer : public internal::BaseHandle<LiteRtTensorBuffer> {
@@ -64,6 +66,14 @@ class TensorBuffer : public internal::BaseHandle<LiteRtTensorBuffer> {
6466
const Environment& env, TensorBufferType buffer_type,
6567
const RankedTensorType& tensor_type, size_t buffer_size);
6668

69+
/// @brief Creates a managed `TensorBuffer` from requirements.
70+
///
71+
/// The returned object is owned by the caller. It automatically selects
72+
/// the best buffer type and applies required alignment and padding.
73+
static Expected<TensorBuffer> CreateManagedFromRequirements(
74+
const Environment& env, const RankedTensorType& tensor_type,
75+
const TensorBufferRequirements& requirements);
76+
6777
/// @brief Creates a managed host memory `TensorBuffer` using the default
6878
/// environment (if applicable).
6979
///

litert/cc/litert_tensor_buffer_requirements.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
#include "absl/types/span.h" // from @com_google_absl
2626
#include "litert/c/internal/litert_logging.h"
27-
#include "litert/c/litert_common.h"
2827
#include "litert/cc/litert_common.h"
2928
#include "litert/cc/litert_expected.h"
3029
#include "litert/cc/litert_tensor_buffer_types.h"
@@ -84,6 +83,10 @@ class TensorBufferRequirements {
8483

8584
Expected<size_t> BufferSize() const { return buffer_size_; }
8685

86+
/// @brief Returns the strides of the tensor buffer requirements.
87+
///
88+
/// If the strides are not specified, either an empty span or a span with a
89+
/// single element of 0 is returned, which is equivalent to no strides.
8790
Expected<absl::Span<const uint32_t>> Strides() const {
8891
return absl::MakeConstSpan(strides_);
8992
}

0 commit comments

Comments
 (0)