@@ -90,6 +90,41 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+// Maximum tensor size (in bytes) that cuDNN can handle safely.
+// cuDNN has internal limits around 2GB for certain operations.
+// We use a conservative threshold to avoid CUDA invalid resource handle errors.
+constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024;  // 2GB
+
+// Helper function to check if the tensor size exceeds the safe limit for cuDNN.
+// Returns true if the tensor is too large and needs fallback processing.
+template <typename T>
+inline bool IsTensorTooLargeForCudnn(const Tensor& tensor) {
+  int64_t tensor_size_bytes = tensor.NumElements() * sizeof(T);
+  return tensor_size_bytes > kMaxCudnnTensorSizeBytes;
+}
+
+// Helper function to compute the maximum batch size that keeps the tensor
+// under the cuDNN size limit.
+template <typename T>
+inline int64_t ComputeSafeBatchSize(const Tensor& tensor, int64_t current_batch,
+                                    TensorFormat data_format) {
+  if (current_batch <= 0) return 1;
+  int64_t total_elements = tensor.NumElements();
+  if (total_elements <= 0) return 1;
+  // Handle the edge case where total_elements < current_batch.
+  if (total_elements < current_batch) {
+    // Each batch has fewer than one element on average; return 1.
+    return 1;
+  }
+  int64_t elements_per_batch = total_elements / current_batch;
+  if (elements_per_batch <= 0) return 1;
+  int64_t max_elements = kMaxCudnnTensorSizeBytes / sizeof(T);
+  int64_t safe_batch = max_elements / elements_per_batch;
+  // Ensure a batch size of at least 1, and cap at the current batch size.
+  return std::max(static_cast<int64_t>(1),
+                  std::min(safe_batch, current_batch));
+}
+
 template <typename Device, typename T>
 struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
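
For reviewers: below is a minimal, standalone sketch of the safe-batch arithmetic introduced above, using plain integers instead of tensorflow::Tensor. It assumes 4-byte (float) elements and the same 2 GB threshold as kMaxCudnnTensorSizeBytes; the helper and variable names here are illustrative only and are not part of the patch.

// Standalone sketch (not part of the patch): reproduces the safe-batch
// arithmetic with plain integers; `elem_size` stands in for sizeof(T).
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int64_t kMaxBytes = 2LL * 1024 * 1024 * 1024;  // 2GB, as in the patch

int64_t SafeBatch(int64_t total_elements, int64_t batch, int64_t elem_size) {
  if (batch <= 0 || total_elements < batch) return 1;
  int64_t elements_per_batch = total_elements / batch;
  if (elements_per_batch <= 0) return 1;
  int64_t max_elements = kMaxBytes / elem_size;
  return std::max<int64_t>(1, std::min(max_elements / elements_per_batch, batch));
}

int main() {
  // Example: float input of shape [64, 256, 512, 512] (NCHW).
  int64_t batch = 64;
  int64_t per_batch = 256LL * 512 * 512;  // 67,108,864 elements per batch
  int64_t total = batch * per_batch;      // ~4.29e9 elements, ~16 GiB as float
  std::printf("safe batch = %lld\n",
              static_cast<long long>(SafeBatch(total, batch, sizeof(float))));
  // Prints 8: 2 GiB / 4 B = 536,870,912 elements; / 67,108,864 per batch = 8.
  return 0;
}
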
@@ -773,6 +808,123 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune,
               absl::InvalidArgumentError("filter must not have zero elements "
                                          "(i.e. all dimensions must be non-zero)"));
 
+  // Check if the input tensor is too large for cuDNN and needs batch splitting.
+  // This addresses CUDA invalid resource handle errors with large tensors.
+  if (IsTensorTooLargeForCudnn<T>(input) && in_batch > 1) {
+    int64_t safe_batch = ComputeSafeBatchSize<T>(input, in_batch, data_format);
+    if (safe_batch < in_batch && safe_batch > 0) {
+      VLOG(2) << "Input tensor too large for cuDNN, splitting batch from "
+              << in_batch << " to chunks of " << safe_batch;
+
+      // Process in batches to avoid cuDNN memory limits.
+      int64_t batch_idx = GetTensorDimIndex(data_format, 'N', input.dims());
+
+      // Validate the batch dimension before proceeding.
+      OP_REQUIRES(context, batch_idx >= 0 && batch_idx < input.dims(),
+                  absl::InternalError("Invalid batch dimension index"));
+      OP_REQUIRES(context, input.dim_size(batch_idx) > 0,
+                  absl::InternalError("Input batch dimension is zero"));
+      OP_REQUIRES(context, output->dim_size(batch_idx) > 0,
+                  absl::InternalError("Output batch dimension is zero"));
+
+      for (int64_t start = 0; start < in_batch; start += safe_batch) {
+        int64_t chunk_size = std::min(safe_batch, in_batch - start);
+
+        // Create the sliced input tensor.
+        std::vector<int64_t> input_slice_shape;
+        for (int i = 0; i < input.dims(); ++i) {
+          if (i == batch_idx) {
+            input_slice_shape.push_back(chunk_size);
+          } else {
+            input_slice_shape.push_back(input.dim_size(i));
+          }
+        }
+        TensorShape input_slice_ts(input_slice_shape);
+        Tensor input_slice;
+        OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                       input_slice_ts,
+                                                       &input_slice));
+
+        // Create the sliced output tensor.
+        std::vector<int64_t> output_slice_shape;
+        for (int i = 0; i < output->dims(); ++i) {
+          if (i == batch_idx) {
+            output_slice_shape.push_back(chunk_size);
+          } else {
+            output_slice_shape.push_back(output->dim_size(i));
+          }
+        }
+        TensorShape output_slice_ts(output_slice_shape);
+        Tensor output_slice;
+        OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                       output_slice_ts,
+                                                       &output_slice));
+
+        // Calculate elements per batch with validated dimensions.
+        int64_t input_batch_dim = input.dim_size(batch_idx);
+        int64_t elements_per_batch = input.NumElements() / input_batch_dim;
+
+        // Validate bounds before pointer arithmetic.
+        int64_t input_offset = start * elements_per_batch;
+        OP_REQUIRES(context, input_offset + chunk_size * elements_per_batch <=
+                                 input.NumElements(),
+                    absl::InternalError("Input slice bounds check failed"));
+
+        // Copy the input slice from the input tensor (device to device).
+        int64_t copy_size_bytes = chunk_size * elements_per_batch * sizeof(T);
+        auto src_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(input.template flat<T>().data() + input_offset),
+            copy_size_bytes);
+        auto dst_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(input_slice.template flat<T>().data()),
+            copy_size_bytes);
+        OP_REQUIRES_OK(context,
+                       stream->MemcpyD2D(&dst_ptr, src_ptr, copy_size_bytes));
+
+        // Recursively call LaunchConvOpImpl with the smaller batch.
+        // Safety note: the recursive call is guaranteed not to re-enter this
+        // batch-splitting code path because:
+        //   1. safe_batch is computed to keep sliced tensors under the size limit.
+        //   2. IsTensorTooLargeForCudnn will return false for the sliced tensor.
+        //   3. Even if it were to trigger, in_batch would equal chunk_size,
+        //      and safe_batch would equal chunk_size, so the condition
+        //      "safe_batch < in_batch" would be false.
+        LaunchConvOpImpl<T>(context, cudnn_use_autotune, input_slice, filter,
+                            dilations, strides, padding, explicit_paddings,
+                            data_format, &output_slice);
+
+        // Check for errors from the recursive call.
+        if (!context->status().ok()) return;
+
+        // Calculate output elements per batch with validated dimensions.
+        int64_t output_batch_dim = output->dim_size(batch_idx);
+        int64_t output_elements_per_batch =
+            output->NumElements() / output_batch_dim;
+
+        // Validate bounds before pointer arithmetic.
+        int64_t output_offset = start * output_elements_per_batch;
+        OP_REQUIRES(
+            context,
+            output_offset + chunk_size * output_elements_per_batch <=
+                output->NumElements(),
+            absl::InternalError("Output slice bounds check failed"));
+
+        // Copy the output slice to the output tensor (device to device).
+        int64_t output_copy_size_bytes =
+            chunk_size * output_elements_per_batch * sizeof(T);
+        auto out_src_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(output_slice.template flat<T>().data()),
+            output_copy_size_bytes);
+        auto out_dst_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(output->template flat<T>().data() + output_offset),
+            output_copy_size_bytes);
+        OP_REQUIRES_OK(context, stream->MemcpyD2D(&out_dst_ptr, out_src_ptr,
+                                                  output_copy_size_bytes));
+      }
+      return;
+    }
+  }
+
   bool is_grouped_convolution = filter_depth != in_depth;
   // check if filter is 1x1 and stride/dilation are all ones
   bool one_filter = true;
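
For reviewers: below is a host-side sketch of the split → convolve → merge control flow added in this hunk, under the assumption that a plain memcpy can stand in for stream->MemcpyD2D and an identity op for the recursive LaunchConvOpImpl call. Names and shapes are illustrative only and not part of the patch; the point is to show that non-divisible batches (e.g. 7 split into chunks of 3) are covered by the std::min clamp on the final chunk.

// Standalone sketch (not part of the patch): mock of the batch-splitting loop.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <numeric>
#include <vector>

int main() {
  const int64_t in_batch = 7, safe_batch = 3;  // deliberately non-divisible
  const int64_t elements_per_batch = 4;        // stands in for C*H*W
  std::vector<float> input(in_batch * elements_per_batch);
  std::vector<float> output(in_batch * elements_per_batch, 0.f);
  std::iota(input.begin(), input.end(), 0.f);

  for (int64_t start = 0; start < in_batch; start += safe_batch) {
    const int64_t chunk = std::min(safe_batch, in_batch - start);
    std::vector<float> in_slice(chunk * elements_per_batch);
    std::vector<float> out_slice(chunk * elements_per_batch);

    // "MemcpyD2D" input -> input_slice.
    std::memcpy(in_slice.data(), input.data() + start * elements_per_batch,
                in_slice.size() * sizeof(float));

    // Recursive LaunchConvOpImpl stand-in: identity.
    out_slice = in_slice;

    // "MemcpyD2D" output_slice -> output.
    std::memcpy(output.data() + start * elements_per_batch, out_slice.data(),
                out_slice.size() * sizeof(float));
  }
  // output now equals input; the chunks covered batches [0,3), [3,6), [6,7).
  return 0;
}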