Skip to content

Commit 3168986

Browse files
committed
chore: update TS types, refactor getTensorShape
1 parent 5560f47 commit 3168986

2 files changed

Lines changed: 3 additions & 8 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,8 @@ void BaseModel::unload() { module.reset(nullptr); }
155155

156156
std::vector<int32_t>
157157
BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
158-
std::vector<int32_t> tensorShape;
159158
auto sizes = tensor.sizes();
160-
tensorShape.reserve(sizes.size());
161-
for (auto size : sizes) {
162-
tensorShape.push_back(static_cast<int32_t>(size));
163-
}
164-
return tensorShape;
159+
return std::vector<int32_t>(sizes.begin(), sizes.end());
165160
}
166161

167162
} // namespace rnexecutorch

packages/react-native-executorch/src/types/common.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ export type TensorBuffer =
5555
| BigUint64Array;
5656

5757
export interface TensorPtr {
58-
data: TensorBuffer;
59-
shape: number[];
58+
dataPtr: TensorBuffer;
59+
sizes: number[];
6060
scalarType: ScalarType;
6161
}

0 commit comments

Comments
 (0)