@@ -90,41 +90,6 @@ namespace tensorflow {
9090typedef Eigen::ThreadPoolDevice CPUDevice;
9191typedef Eigen::GpuDevice GPUDevice;
9292
93- // Maximum tensor size (in bytes) that cuDNN can handle safely.
94- // cuDNN has internal limits around 2GB for certain operations.
95- // We use a conservative threshold to avoid CUDA invalid resource handle errors.
96- constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024 ; // 2GB
97-
98- // Helper function to check if the tensor size exceeds the safe limit for cuDNN.
99- // Returns true if the tensor is too large and needs fallback processing.
100- template <typename T>
101- inline bool IsTensorTooLargeForCudnn (const Tensor& tensor) {
102- int64_t tensor_size_bytes = tensor.NumElements () * sizeof (T);
103- return tensor_size_bytes > kMaxCudnnTensorSizeBytes ;
104- }
105-
106- // Helper function to compute the maximum batch size that keeps the tensor
107- // under the cuDNN size limit.
108- template <typename T>
109- inline int64_t ComputeSafeBatchSize (const Tensor& tensor, int64_t current_batch,
110- TensorFormat data_format) {
111- if (current_batch <= 0 ) return 1 ;
112- int64_t total_elements = tensor.NumElements ();
113- if (total_elements <= 0 ) return 1 ;
114- // Handle edge case where total_elements < current_batch
115- if (total_elements < current_batch) {
116- // Each batch has less than 1 element on average, return 1
117- return 1 ;
118- }
119- int64_t elements_per_batch = total_elements / current_batch;
120- if (elements_per_batch <= 0 ) return 1 ;
121- int64_t max_elements = kMaxCudnnTensorSizeBytes / sizeof (T);
122- int64_t safe_batch = max_elements / elements_per_batch;
123- // Ensure at least batch size of 1, and cap at current batch size
124- return std::max (static_cast <int64_t >(1 ),
125- std::min (safe_batch, current_batch));
126- }
127-
12893template <typename Device, typename T>
12994struct LaunchGeneric {
13095 void operator ()(OpKernelContext* ctx, const Tensor& input,
@@ -808,123 +773,6 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune,
808773 absl::InvalidArgumentError (" filter must not have zero elements "
809774 " (i.e. all dimensions must be non-zero)" ));
810775
811- // Check if input tensor is too large for cuDNN and needs batch splitting.
812- // This addresses CUDA invalid resource handle errors with large tensors.
813- if (IsTensorTooLargeForCudnn<T>(input) && in_batch > 1 ) {
814- int64_t safe_batch = ComputeSafeBatchSize<T>(input, in_batch, data_format);
815- if (safe_batch < in_batch && safe_batch > 0 ) {
816- VLOG (2 ) << " Input tensor too large for cuDNN, splitting batch from "
817- << in_batch << " to chunks of " << safe_batch;
818-
819- // Process in batches to avoid cuDNN memory limits
820- int64_t batch_idx = GetTensorDimIndex (data_format, ' N' , input.dims ());
821-
822- // Validate batch dimension before proceeding
823- OP_REQUIRES (context, batch_idx >= 0 && batch_idx < input.dims (),
824- absl::InternalError (" Invalid batch dimension index" ));
825- OP_REQUIRES (context, input.dim_size (batch_idx) > 0 ,
826- absl::InternalError (" Input batch dimension is zero" ));
827- OP_REQUIRES (context, output->dim_size (batch_idx) > 0 ,
828- absl::InternalError (" Output batch dimension is zero" ));
829-
830- for (int64_t start = 0 ; start < in_batch; start += safe_batch) {
831- int64_t chunk_size = std::min (safe_batch, in_batch - start);
832-
833- // Create sliced input tensor
834- std::vector<int64_t > input_slice_shape;
835- for (int i = 0 ; i < input.dims (); ++i) {
836- if (i == batch_idx) {
837- input_slice_shape.push_back (chunk_size);
838- } else {
839- input_slice_shape.push_back (input.dim_size (i));
840- }
841- }
842- TensorShape input_slice_ts (input_slice_shape);
843- Tensor input_slice;
844- OP_REQUIRES_OK (context, context->allocate_temp (DataTypeToEnum<T>::value,
845- input_slice_ts,
846- &input_slice));
847-
848- // Create sliced output tensor
849- std::vector<int64_t > output_slice_shape;
850- for (int i = 0 ; i < output->dims (); ++i) {
851- if (i == batch_idx) {
852- output_slice_shape.push_back (chunk_size);
853- } else {
854- output_slice_shape.push_back (output->dim_size (i));
855- }
856- }
857- TensorShape output_slice_ts (output_slice_shape);
858- Tensor output_slice;
859- OP_REQUIRES_OK (context, context->allocate_temp (DataTypeToEnum<T>::value,
860- output_slice_ts,
861- &output_slice));
862-
863- // Calculate elements per batch with validated dimensions
864- int64_t input_batch_dim = input.dim_size (batch_idx);
865- int64_t elements_per_batch = input.NumElements () / input_batch_dim;
866-
867- // Validate bounds before pointer arithmetic
868- int64_t input_offset = start * elements_per_batch;
869- OP_REQUIRES (context, input_offset + chunk_size * elements_per_batch <=
870- input.NumElements (),
871- absl::InternalError (" Input slice bounds check failed" ));
872-
873- // Copy input slice from input tensor (device to device)
874- int64_t copy_size_bytes = chunk_size * elements_per_batch * sizeof (T);
875- auto src_ptr = se::DeviceMemoryBase (
876- const_cast <T*>(input.template flat <T>().data () + input_offset),
877- copy_size_bytes);
878- auto dst_ptr = se::DeviceMemoryBase (
879- const_cast <T*>(input_slice.template flat <T>().data ()),
880- copy_size_bytes);
881- OP_REQUIRES_OK (context,
882- stream->MemcpyD2D (&dst_ptr, src_ptr, copy_size_bytes));
883-
884- // Recursively call LaunchConvOpImpl with the smaller batch.
885- // Safety note: The recursive call is guaranteed not to re-enter this
886- // batch-splitting code path because:
887- // 1. safe_batch is computed to keep sliced tensors under the size limit
888- // 2. IsTensorTooLargeForCudnn will return false for the sliced tensor
889- // 3. Even if it were to trigger, in_batch would equal chunk_size,
890- // and safe_batch would equal chunk_size, so the condition
891- // "safe_batch < in_batch" would be false
892- LaunchConvOpImpl<T>(context, cudnn_use_autotune, input_slice, filter,
893- dilations, strides, padding, explicit_paddings,
894- data_format, &output_slice);
895-
896- // Check for errors from recursive call
897- if (!context->status ().ok ()) return ;
898-
899- // Calculate output elements per batch with validated dimensions
900- int64_t output_batch_dim = output->dim_size (batch_idx);
901- int64_t output_elements_per_batch =
902- output->NumElements () / output_batch_dim;
903-
904- // Validate bounds before pointer arithmetic
905- int64_t output_offset = start * output_elements_per_batch;
906- OP_REQUIRES (
907- context,
908- output_offset + chunk_size * output_elements_per_batch <=
909- output->NumElements (),
910- absl::InternalError (" Output slice bounds check failed" ));
911-
912- // Copy output slice to output tensor (device to device)
913- int64_t output_copy_size_bytes =
914- chunk_size * output_elements_per_batch * sizeof (T);
915- auto out_src_ptr = se::DeviceMemoryBase (
916- const_cast <T*>(output_slice.template flat <T>().data ()),
917- output_copy_size_bytes);
918- auto out_dst_ptr = se::DeviceMemoryBase (
919- const_cast <T*>(output->template flat <T>().data () + output_offset),
920- output_copy_size_bytes);
921- OP_REQUIRES_OK (context, stream->MemcpyD2D (&out_dst_ptr, out_src_ptr,
922- output_copy_size_bytes));
923- }
924- return ;
925- }
926- }
927-
928776 bool is_grouped_convolution = filter_depth != in_depth;
929777 // check if filter is 1x1 and stride/dilation are all ones
930778 bool one_filter = true ;
0 commit comments