11#include < libtorchaudio/rnnt/cpu/cpu_transducer.h>
2+ #include < libtorchaudio/utils.h>
3+
24#include < torch/csrc/stable/library.h>
35#include < torch/csrc/stable/ops.h>
46#include < torch/csrc/stable/tensor.h>
@@ -73,15 +75,11 @@ std::tuple<Tensor, Tensor> compute(
7375 STD_TORCH_CHECK (
7476 blank >= 0 && blank < logits.size (-1 ),
7577 " blank must be within [0, logits.shape[-1])" );
76-
77- auto max_ivalue = [](const Tensor& t) {
78- return reinterpret_cast <int32_t *>(torch::stable::amax (t, {}).data_ptr ())[0 ];
79- };
80-
8178 STD_TORCH_CHECK (
82- logits.size (1 ) == max_ivalue (logit_lengths), " input length mismatch" );
79+ logits.size (1 ) == torchaudio::util::max<int64_t >(logit_lengths),
80+ " input length mismatch" );
8381 STD_TORCH_CHECK (
84- logits.size (2 ) == max_ivalue (target_lengths) + 1 ,
82+ logits.size (2 ) == torchaudio::util::max< int64_t > (target_lengths) + 1 ,
8583 " output length mismatch" );
8684 STD_TORCH_CHECK (
8785 targets.size (1 ) + 1 == logits.size (2 ), " target length mismatch" );
@@ -110,14 +108,12 @@ std::tuple<Tensor, Tensor> compute(
110108 {DtypeWorkspace<float >::ComputeSizeFromOptions (options)},
111109 ScalarType::Float);
112110
113- // TODO: use t.mutable_data_ptr<..>() instead of reinterpret_cast
114- // when stable ABI Tensor supports mutable_data_ptr templates.
115111 Workspace<float > workspace (
116112 /* options=*/ options,
117113 /* dtype_data=*/
118- reinterpret_cast <float *>(float_workspace. mutable_data_ptr () ),
114+ float_workspace. mutable_data_ptr <float >( ),
119115 /* dtype_size=*/ float_workspace.numel (),
120- /* int_data=*/ reinterpret_cast <int *>(int_workspace. mutable_data_ptr () ),
116+ /* int_data=*/ int_workspace. mutable_data_ptr <int >( ),
121117 /* int_size=*/ int_workspace.numel ());
122118
123119 THO_DISPATCH_V2 (
@@ -126,12 +122,12 @@ std::tuple<Tensor, Tensor> compute(
126122 AT_WRAP ([&] {
127123 (Compute</* DTYPE=*/ scalar_t , /* CAST_DTYPE=*/ float >(
128124 /* workspace=*/ workspace,
129- /* logits=*/ reinterpret_cast <scalar_t *>(logits. data_ptr () ),
130- /* targets=*/ reinterpret_cast <int *>(targets. data_ptr () ),
131- /* srcLengths=*/ reinterpret_cast <int *>(logit_lengths. data_ptr () ),
132- /* tgtLengths=*/ reinterpret_cast <int *>(target_lengths. data_ptr () ),
133- /* costs=*/ reinterpret_cast <scalar_t *>(costs. data_ptr () ),
134- /* gradients=*/ reinterpret_cast <scalar_t *>(gradients. data_ptr () )));
125+ /* logits=*/ logits. const_data_ptr <scalar_t >( ),
126+ /* targets=*/ targets. const_data_ptr <int >( ),
127+ /* srcLengths=*/ logit_lengths. const_data_ptr <int >( ),
128+ /* tgtLengths=*/ target_lengths. const_data_ptr <int >( ),
129+ /* costs=*/ costs. mutable_data_ptr <scalar_t >( ),
130+ /* gradients=*/ gradients. mutable_data_ptr <scalar_t >( )));
135131 }),
136132 ScalarType::Float,
137133 ScalarType::Half);
0 commit comments