#include "core/providers/cpu/rnn/rnn.h"

#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/cpu/rnn/rnn_activation_functors.h"
@@ -84,15 +85,32 @@ void ApplyActivationToBatches(const Tensor* sequence_lens, const T* h_prev, T* Y
8485template <typename T>
8586void Assign_Y_h (const T* Y_buffer_data, Tensor* Y_h, const Tensor* sequence_lens,
8687 int64_t num_directions, int direction, bool isReverse, int64_t batch_size, int64_t seq_length, int64_t hidden_size) {
88+ if (seq_length == 0 ) {
89+ // No sequence data was processed; zero out Y_h for this direction.
90+ const size_t y_h_direction_size = SafeMul<size_t >(batch_size, hidden_size);
91+ const size_t Y_h_direction_offset = SafeMul<size_t >(direction, y_h_direction_size);
92+ math::Set<T, CPUMathUtil>(y_h_direction_size, T{0 },
93+ Y_h->MutableData <T>() + Y_h_direction_offset, &CPUMathUtil::Instance ());
94+ return ;
95+ }
96+
8797 for (int batch = 0 ; batch < batch_size; batch++) {
8898 int64_t last_time_step = isReverse ? 0 : seq_length - 1 ;
89- if (nullptr != sequence_lens && !isReverse)
99+ if (nullptr != sequence_lens && !isReverse) {
90100 last_time_step = sequence_lens->Data <int >()[batch] - 1 ;
101+ if (last_time_step < 0 ) {
102+ // sequence_lens[batch] == 0: no data was processed for this batch; zero out Y_h.
103+ int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
104+ math::Set<T, CPUMathUtil>(narrow<size_t >(hidden_size), T{0 },
105+ Y_h->MutableData <T>() + Y_h_offset, &CPUMathUtil::Instance ());
106+ continue ;
107+ }
108+ }
91109 int64_t y_offset = last_time_step * num_directions * batch_size * hidden_size +
92110 direction * batch_size * hidden_size +
93111 batch * hidden_size;
94112 int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
95- math::CopyVector<T, CPUMathUtil>(static_cast <int >(hidden_size), Y_buffer_data + y_offset,
113+ math::CopyVector<T, CPUMathUtil>(narrow <int >(hidden_size), Y_buffer_data + y_offset,
96114 Y_h->MutableData <T>() + Y_h_offset,
97115 &CPUMathUtil::Instance ());
98116 }
@@ -109,7 +127,7 @@ void ClearMissingFrames(T* Y_buffer_data, const Tensor* sequence_lens,
109127 seq * num_directions * batch_size * hidden_size +
110128 direction * batch_size * hidden_size +
111129 batch * hidden_size;
112- math::Set<T, CPUMathUtil>(onnxruntime:: narrow<size_t >(hidden_size), 0 , Y_buffer_data + offset, &CPUMathUtil::Instance ());
130+ math::Set<T, CPUMathUtil>(narrow<size_t >(hidden_size), 0 , Y_buffer_data + offset, &CPUMathUtil::Instance ());
113131 }
114132 }
115133 }
@@ -155,7 +173,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
155173 ORT_RETURN_IF_ERROR (ctx->GetTempSpaceAllocator (&alloc));
156174
157175 // X * W^t, each direction has shape of [seq_length, batch_size, hidden_size]
158- auto x_matmul_data = alloc->Alloc (SafeInt <size_t >(sizeof (float )) * seq_length * batch_size * hidden_size_);
176+ auto x_matmul_data = alloc->Alloc (SafeMul <size_t >(sizeof (float ), seq_length, batch_size, hidden_size_) );
159177 BufferUniquePtr x_matmul_buffer (x_matmul_data, BufferDeleter (alloc));
160178 auto * x_matmul_w_buffer_data = static_cast <float *>(x_matmul_buffer.get ());
161179
@@ -165,7 +183,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
165183 if (Y != nullptr )
166184 Y_buffer_data = Y->MutableData <float >();
167185 else {
168- Y_data = alloc->Alloc (SafeInt <size_t >(sizeof (float )) * seq_length * num_directions * batch_size * hidden_size_);
186+ Y_data = alloc->Alloc (SafeMul <size_t >(sizeof (float ), seq_length, num_directions, batch_size, hidden_size_) );
169187 Y_matmul_buffer = BufferUniquePtr (Y_data, BufferDeleter (alloc));
170188 Y_buffer_data = static_cast <float *>(Y_matmul_buffer.get ());
171189 }
@@ -177,20 +195,20 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
177195 bool isReverse = direction_ == " reverse" || direction == 1 ;
178196
179197 if (B != nullptr ) {
180- EigenMatrixMapRowMajor<float >(x_matmul_w_buffer_data, seq_length * SafeInt <size_t >(batch_size), onnxruntime:: narrow<size_t >(hidden_size_)).rowwise () =
181- ConstEigenVectorMap<float >(B->Data <float >() + direction * 2 * hidden_size_, onnxruntime:: narrow<size_t >(hidden_size_)).transpose () +
182- ConstEigenVectorMap<float >(B->Data <float >() + direction * 2 * hidden_size_ + hidden_size_, onnxruntime:: narrow<size_t >(hidden_size_)).transpose ();
198+ EigenMatrixMapRowMajor<float >(x_matmul_w_buffer_data, SafeMul <size_t >(seq_length, batch_size), narrow<size_t >(hidden_size_)).rowwise () =
199+ ConstEigenVectorMap<float >(B->Data <float >() + direction * 2 * hidden_size_, narrow<size_t >(hidden_size_)).transpose () +
200+ ConstEigenVectorMap<float >(B->Data <float >() + direction * 2 * hidden_size_ + hidden_size_, narrow<size_t >(hidden_size_)).transpose ();
183201 } else {
184- math::Set<float , CPUMathUtil>(seq_length * batch_size * SafeInt <size_t >(hidden_size_), 0 , x_matmul_w_buffer_data, &CPUMathUtil::Instance ());
202+ math::Set<float , CPUMathUtil>(SafeMul <size_t >(seq_length, batch_size, hidden_size_), 0 , x_matmul_w_buffer_data, &CPUMathUtil::Instance ());
185203 }
186204
187205 // X * W[direction]^t + B
188206 math::Gemm<float >(
189207 CblasNoTrans,
190208 CblasTrans,
191- static_cast <int >(seq_length * batch_size),
192- static_cast <int >(hidden_size_),
193- static_cast <int >(input_size),
209+ SafeMul <int >(seq_length, batch_size),
210+ narrow <int >(hidden_size_),
211+ narrow <int >(input_size),
194212 1 ,
195213 X.Data <float >(),
196214 W.Data <float >() + direction * hidden_size_ * input_size,
@@ -202,7 +220,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
202220 int64_t time_step = isReverse ? (seq_length - t - 1 ) : t;
203221 int64_t Y_frame_offset = (time_step * num_directions + direction) * Y_frame_size;
204222 float * Y_buffer_data_current_frame = Y_buffer_data + Y_frame_offset;
205- auto y_frame_mat = EigenMatrixMapRowMajor<float >(Y_buffer_data_current_frame, onnxruntime:: narrow<size_t >(batch_size), onnxruntime:: narrow<size_t >(hidden_size_));
223+ auto y_frame_mat = EigenMatrixMapRowMajor<float >(Y_buffer_data_current_frame, narrow<size_t >(batch_size), narrow<size_t >(hidden_size_));
206224
207225 const float * h_prev = nullptr ;
208226 if (t == 0 ) {
@@ -224,21 +242,21 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
224242 math::Gemm<float >(
225243 CblasNoTrans,
226244 CblasTrans,
227- static_cast <int >(batch_size),
228- static_cast <int >(hidden_size_),
229- static_cast <int >(hidden_size_),
245+ narrow <int >(batch_size),
246+ narrow <int >(hidden_size_),
247+ narrow <int >(hidden_size_),
230248 1 ,
231249 h_prev,
232250 R.Data <float >() + direction * hidden_size_ * hidden_size_,
233251 0 ,
234252 Y_buffer_data_current_frame,
235253 tp, &mlas_backend_kernel_selector_config_);
236254 } else {
237- math::Set<float , CPUMathUtil>(batch_size * SafeInt <size_t >(hidden_size_), 0 , Y_buffer_data_current_frame, &CPUMathUtil::Instance ());
255+ math::Set<float , CPUMathUtil>(SafeMul <size_t >(batch_size, hidden_size_), 0 , Y_buffer_data_current_frame, &CPUMathUtil::Instance ());
238256 }
239257
240258 // X[time_step] * W^t + H_t_1 * R^t
241- y_frame_mat += EigenMatrixMapRowMajor<float >(&x_matmul_w_buffer_data[time_step * Y_frame_size], onnxruntime:: narrow<size_t >(batch_size), onnxruntime:: narrow<size_t >(hidden_size_));
259+ y_frame_mat += EigenMatrixMapRowMajor<float >(&x_matmul_w_buffer_data[time_step * Y_frame_size], narrow<size_t >(batch_size), narrow<size_t >(hidden_size_));
242260
243261 // apply activation
244262 ApplyActivationToBatches<float >(sequence_lens, h_prev, Y_buffer_data_current_frame,
@@ -258,10 +276,10 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
258276 }
259277
260278 if (Y != nullptr )
261- DumpMatrix (" Y" , Y_buffer_data, ( int ) (seq_length * num_directions * batch_size), ( int ) hidden_size_);
279+ DumpMatrix (" Y" , Y_buffer_data, SafeMul< int > (seq_length, num_directions, batch_size), narrow< int >( hidden_size_) );
262280
263281 if (Y_h != nullptr )
264- DumpMatrix (" Y_h" , Y_h->Data <float >(), ( int ) (num_directions * batch_size), ( int ) hidden_size_);
282+ DumpMatrix (" Y_h" , Y_h->Data <float >(), SafeMul< int > (num_directions, batch_size), narrow< int >( hidden_size_) );
265283
266284 return Status::OK ();
267285}
0 commit comments