@@ -319,24 +319,9 @@ Error defineTensor(
319319 ET_CHECK_OR_RETURN_ERROR (
320320 tensor_value != nullptr , InvalidProgram, " Deserialized tensor is null" );
321321
322- ET_CHECK_OR_RETURN_ERROR (
323- tensor_value->num_dims () == 0 || tensor_value->dims () != nullptr ,
324- InvalidProgram,
325- " Tensor dims is null but num_dims is %u" ,
326- tensor_value->num_dims ());
327-
328- if (tensor_value->dims () != nullptr ) {
329- ET_CHECK_OR_RETURN_ERROR (
330- tensor_value->num_dims () == tensor_value->dims ()->size (),
331- InvalidProgram,
332- " Tensor num_dims %u does not match dims array size %u" ,
333- tensor_value->num_dims (),
334- tensor_value->dims ()->size ());
335- }
336-
337- // Get tensor dims, here we need to use a vector in order
338- // to properly convert the uint32_t* to size_t*. For scalar tensors
339- // (num_dims == 0), dims() is permitted to be null per the check above.
322+ // Get tensor dims, here we need to use a vector in order to properly
323+ // convert the uint32_t* to size_t*. Scalar tensors (rank 0) are permitted
324+ // to have a null dims vector; in that case dims_data is empty.
340325 std::vector<size_t > dims_data;
341326 if (tensor_value->dims () != nullptr ) {
342327 dims_data = flatbufferDimsToVector (tensor_value->dims ());
@@ -386,7 +371,7 @@ Error defineTensor(
386371 status = xnn_define_tensor_value (
387372 /* subgraph=*/ subgraph_ptr,
388373 /* datatype=*/ getDataType (tensor_value->datatype ()),
389- /* num_dims=*/ tensor_value-> num_dims (),
374+ /* num_dims=*/ dims_data. size (),
390375 /* dims=*/ dims_data.data (),
391376 /* data=*/ buffer_ptr,
392377 /* external_id=*/ tensor_value->external_id (),
@@ -421,7 +406,7 @@ Error defineTensor(
421406 status = xnn_define_dynamically_quantized_tensor_value (
422407 /* subgraph=*/ subgraph_ptr,
423408 /* datatype=*/ xnn_datatype_qdint8,
424- /* num_dims=*/ tensor_value-> num_dims (),
409+ /* num_dims=*/ dims_data. size (),
425410 /* num_nonbatch_dims=*/ 1 , // always do per token quantization
426411 /* dims=*/ dims_data.data (),
427412 /* external_id=*/ XNN_INVALID_VALUE_ID, // always internal value id
@@ -435,7 +420,7 @@ Error defineTensor(
435420 status = xnn_define_tensor_value (
436421 /* subgraph=*/ subgraph_ptr,
437422 /* datatype=*/ fp_datatype,
438- /* num_dims=*/ tensor_value-> num_dims (),
423+ /* num_dims=*/ dims_data. size (),
439424 /* dims=*/ dims_data.data (),
440425 /* data=*/ buffer_ptr,
441426 /* external_id=*/ tensor_value->external_id (),
@@ -476,7 +461,7 @@ Error defineTensor(
476461 /* datatype=*/ getDataType (tensor_value->datatype ()),
477462 /* zero_point=*/ qparams->zero_point (),
478463 /* scale=*/ qparams->scale (),
479- /* num_dims=*/ tensor_value-> num_dims (),
464+ /* num_dims=*/ dims_data. size (),
480465 /* dims=*/ dims_data.data (),
481466 /* data=*/ buffer_ptr,
482467 /* external_id=*/ tensor_value->external_id (),
@@ -521,7 +506,7 @@ Error defineTensor(
521506 /* datatype=*/ dtype,
522507 /* zero_point=*/ zero_point,
523508 /* scale=*/ scale,
524- /* num_dims=*/ tensor_value-> num_dims (),
509+ /* num_dims=*/ dims_data. size (),
525510 /* channel_dim*/ qparams->channel_dim (),
526511 /* dims=*/ dims_data.data (),
527512 /* data=*/ buffer_ptr,
@@ -599,7 +584,7 @@ Error defineTensor(
599584 /* datatype=*/ datatype,
600585 /* zero_point=*/ zero_point,
601586 /* scale=*/ scale_data,
602- /* num_dims=*/ tensor_value-> num_dims (),
587+ /* num_dims=*/ dims_data. size (),
603588 /* channel_dim=*/ qparams->channel_dim (),
604589 /* block_size=*/ qparams->group_size (),
605590 /* dims=*/ dims_data.data (),
@@ -613,8 +598,8 @@ Error defineTensor(
613598 auto qparams = qtensor_value->quant_params_as_PerTokenDynamicQuant ();
614599 ET_LOG (
615600 Debug,
616- " define quant tensor (dynamic): num_dims: %i , num_nonbatch_dims: %i\n " ,
617- tensor_value-> num_dims (),
601+ " define quant tensor (dynamic): num_dims: %zu , num_nonbatch_dims: %i\n " ,
602+ dims_data. size (),
618603 qparams->num_nonbatch_dims ());
619604 ET_CHECK_OR_RETURN_ERROR (
620605 buffer_ptr == nullptr ,
@@ -623,7 +608,7 @@ Error defineTensor(
623608 status = xnn_define_dynamically_quantized_tensor_value (
624609 /* subgraph=*/ subgraph_ptr,
625610 /* datatype=*/ getDataType (tensor_value->datatype ()),
626- /* num_dims=*/ tensor_value-> num_dims (),
611+ /* num_dims=*/ dims_data. size (),
627612 /* num_nonbatch_dims*/ qparams->num_nonbatch_dims (),
628613 /* dims=*/ dims_data.data (),
629614 /* external_id=*/ tensor_value->external_id (),
0 commit comments