@@ -341,10 +341,15 @@ int main(int argc, char** argv) {
341341 " Failed to allocate custom memory. tensor index: %d, bytes: %zu" ,
342342 input_index,
343343 tensor_meta->nbytes ());
344+ ssize_t numel = 1 ;
345+ for (size_t i = 0 ; i < tensor_meta->sizes ().size (); i++) {
346+ numel *= tensor_meta->sizes ()[i];
347+ }
344348 TensorImpl impl = TensorImpl (
345349 tensor_meta->scalar_type (),
346350 /* dim=*/ tensor_meta->sizes ().size (),
347351 const_cast <TensorImpl::SizesType*>(tensor_meta->sizes ().data ()),
352+ numel,
348353 custom_mem_ptr->GetPtr (),
349354 const_cast <TensorImpl::DimOrderType*>(tensor_meta->dim_order ().data ()));
350355 Error ret = method->set_input (Tensor (&impl), input_index);
@@ -475,13 +480,19 @@ int main(int argc, char** argv) {
475480 // For pre-allocated use case, we need to call set_input
476481 // to copy data for the input tensors since they doesn't
477482 // share the data with in_custom_mem.
483+ const auto * sizes_ptr = expected_input_shapes.empty ()
484+ ? tensor_meta->sizes ().data ()
485+ : expected_input_shapes[input_index].data ();
486+ ssize_t dim = tensor_meta->sizes ().size ();
487+ ssize_t numel = 1 ;
488+ for (ssize_t i = 0 ; i < dim; i++) {
489+ numel *= sizes_ptr[i];
490+ }
478491 TensorImpl impl = TensorImpl (
479492 tensor_meta->scalar_type (),
480- /* dim=*/ tensor_meta->sizes ().size (),
481- const_cast <TensorImpl::SizesType*>(
482- expected_input_shapes.empty ()
483- ? tensor_meta->sizes ().data ()
484- : expected_input_shapes[input_index].data ()),
493+ /* dim=*/ dim,
494+ const_cast <TensorImpl::SizesType*>(sizes_ptr),
495+ numel,
485496 in_custom_mem[input_index]->GetPtr (),
486497 const_cast <TensorImpl::DimOrderType*>(
487498 tensor_meta->dim_order ().data ()));
0 commit comments