Skip to content

Commit 0a85870

Browse files
Remove unrelated cuDNN tensor size checking code from conv_ops_impl.h
Address code review feedback: limit changes only to the complex variable conj segfault fix. The cuDNN batch splitting code was unrelated to the initializers issue and has been removed to keep the PR focused.
1 parent 23f3001 commit 0a85870

1 file changed

Lines changed: 0 additions & 152 deletions

File tree

tensorflow/core/kernels/conv_ops_impl.h

Lines changed: 0 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -90,41 +90,6 @@ namespace tensorflow {
9090
typedef Eigen::ThreadPoolDevice CPUDevice;
9191
typedef Eigen::GpuDevice GPUDevice;
9292

93-
// Maximum tensor size (in bytes) that cuDNN can handle safely.
94-
// cuDNN has internal limits around 2GB for certain operations.
95-
// We use a conservative threshold to avoid CUDA invalid resource handle errors.
96-
constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024; // 2GB
97-
98-
// Helper function to check if the tensor size exceeds the safe limit for cuDNN.
99-
// Returns true if the tensor is too large and needs fallback processing.
100-
template <typename T>
101-
inline bool IsTensorTooLargeForCudnn(const Tensor& tensor) {
102-
int64_t tensor_size_bytes = tensor.NumElements() * sizeof(T);
103-
return tensor_size_bytes > kMaxCudnnTensorSizeBytes;
104-
}
105-
106-
// Helper function to compute the maximum batch size that keeps the tensor
107-
// under the cuDNN size limit.
108-
template <typename T>
109-
inline int64_t ComputeSafeBatchSize(const Tensor& tensor, int64_t current_batch,
110-
TensorFormat data_format) {
111-
if (current_batch <= 0) return 1;
112-
int64_t total_elements = tensor.NumElements();
113-
if (total_elements <= 0) return 1;
114-
// Handle edge case where total_elements < current_batch
115-
if (total_elements < current_batch) {
116-
// Each batch has less than 1 element on average, return 1
117-
return 1;
118-
}
119-
int64_t elements_per_batch = total_elements / current_batch;
120-
if (elements_per_batch <= 0) return 1;
121-
int64_t max_elements = kMaxCudnnTensorSizeBytes / sizeof(T);
122-
int64_t safe_batch = max_elements / elements_per_batch;
123-
// Ensure at least batch size of 1, and cap at current batch size
124-
return std::max(static_cast<int64_t>(1),
125-
std::min(safe_batch, current_batch));
126-
}
127-
12893
template <typename Device, typename T>
12994
struct LaunchGeneric {
13095
void operator()(OpKernelContext* ctx, const Tensor& input,
@@ -808,123 +773,6 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune,
808773
absl::InvalidArgumentError("filter must not have zero elements "
809774
"(i.e. all dimensions must be non-zero)"));
810775

811-
// Check if input tensor is too large for cuDNN and needs batch splitting.
812-
// This addresses CUDA invalid resource handle errors with large tensors.
813-
if (IsTensorTooLargeForCudnn<T>(input) && in_batch > 1) {
814-
int64_t safe_batch = ComputeSafeBatchSize<T>(input, in_batch, data_format);
815-
if (safe_batch < in_batch && safe_batch > 0) {
816-
VLOG(2) << "Input tensor too large for cuDNN, splitting batch from "
817-
<< in_batch << " to chunks of " << safe_batch;
818-
819-
// Process in batches to avoid cuDNN memory limits
820-
int64_t batch_idx = GetTensorDimIndex(data_format, 'N', input.dims());
821-
822-
// Validate batch dimension before proceeding
823-
OP_REQUIRES(context, batch_idx >= 0 && batch_idx < input.dims(),
824-
absl::InternalError("Invalid batch dimension index"));
825-
OP_REQUIRES(context, input.dim_size(batch_idx) > 0,
826-
absl::InternalError("Input batch dimension is zero"));
827-
OP_REQUIRES(context, output->dim_size(batch_idx) > 0,
828-
absl::InternalError("Output batch dimension is zero"));
829-
830-
for (int64_t start = 0; start < in_batch; start += safe_batch) {
831-
int64_t chunk_size = std::min(safe_batch, in_batch - start);
832-
833-
// Create sliced input tensor
834-
std::vector<int64_t> input_slice_shape;
835-
for (int i = 0; i < input.dims(); ++i) {
836-
if (i == batch_idx) {
837-
input_slice_shape.push_back(chunk_size);
838-
} else {
839-
input_slice_shape.push_back(input.dim_size(i));
840-
}
841-
}
842-
TensorShape input_slice_ts(input_slice_shape);
843-
Tensor input_slice;
844-
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
845-
input_slice_ts,
846-
&input_slice));
847-
848-
// Create sliced output tensor
849-
std::vector<int64_t> output_slice_shape;
850-
for (int i = 0; i < output->dims(); ++i) {
851-
if (i == batch_idx) {
852-
output_slice_shape.push_back(chunk_size);
853-
} else {
854-
output_slice_shape.push_back(output->dim_size(i));
855-
}
856-
}
857-
TensorShape output_slice_ts(output_slice_shape);
858-
Tensor output_slice;
859-
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
860-
output_slice_ts,
861-
&output_slice));
862-
863-
// Calculate elements per batch with validated dimensions
864-
int64_t input_batch_dim = input.dim_size(batch_idx);
865-
int64_t elements_per_batch = input.NumElements() / input_batch_dim;
866-
867-
// Validate bounds before pointer arithmetic
868-
int64_t input_offset = start * elements_per_batch;
869-
OP_REQUIRES(context, input_offset + chunk_size * elements_per_batch <=
870-
input.NumElements(),
871-
absl::InternalError("Input slice bounds check failed"));
872-
873-
// Copy input slice from input tensor (device to device)
874-
int64_t copy_size_bytes = chunk_size * elements_per_batch * sizeof(T);
875-
auto src_ptr = se::DeviceMemoryBase(
876-
const_cast<T*>(input.template flat<T>().data() + input_offset),
877-
copy_size_bytes);
878-
auto dst_ptr = se::DeviceMemoryBase(
879-
const_cast<T*>(input_slice.template flat<T>().data()),
880-
copy_size_bytes);
881-
OP_REQUIRES_OK(context,
882-
stream->MemcpyD2D(&dst_ptr, src_ptr, copy_size_bytes));
883-
884-
// Recursively call LaunchConvOpImpl with the smaller batch.
885-
// Safety note: The recursive call is guaranteed not to re-enter this
886-
// batch-splitting code path because:
887-
// 1. safe_batch is computed to keep sliced tensors under the size limit
888-
// 2. IsTensorTooLargeForCudnn will return false for the sliced tensor
889-
// 3. Even if it were to trigger, in_batch would equal chunk_size,
890-
// and safe_batch would equal chunk_size, so the condition
891-
// "safe_batch < in_batch" would be false
892-
LaunchConvOpImpl<T>(context, cudnn_use_autotune, input_slice, filter,
893-
dilations, strides, padding, explicit_paddings,
894-
data_format, &output_slice);
895-
896-
// Check for errors from recursive call
897-
if (!context->status().ok()) return;
898-
899-
// Calculate output elements per batch with validated dimensions
900-
int64_t output_batch_dim = output->dim_size(batch_idx);
901-
int64_t output_elements_per_batch =
902-
output->NumElements() / output_batch_dim;
903-
904-
// Validate bounds before pointer arithmetic
905-
int64_t output_offset = start * output_elements_per_batch;
906-
OP_REQUIRES(
907-
context,
908-
output_offset + chunk_size * output_elements_per_batch <=
909-
output->NumElements(),
910-
absl::InternalError("Output slice bounds check failed"));
911-
912-
// Copy output slice to output tensor (device to device)
913-
int64_t output_copy_size_bytes =
914-
chunk_size * output_elements_per_batch * sizeof(T);
915-
auto out_src_ptr = se::DeviceMemoryBase(
916-
const_cast<T*>(output_slice.template flat<T>().data()),
917-
output_copy_size_bytes);
918-
auto out_dst_ptr = se::DeviceMemoryBase(
919-
const_cast<T*>(output->template flat<T>().data() + output_offset),
920-
output_copy_size_bytes);
921-
OP_REQUIRES_OK(context, stream->MemcpyD2D(&out_dst_ptr, out_src_ptr,
922-
output_copy_size_bytes));
923-
}
924-
return;
925-
}
926-
}
927-
928776
bool is_grouped_convolution = filter_depth != in_depth;
929777
// check if filter is 1x1 and stride/dilation are all ones
930778
bool one_filter = true;

0 commit comments

Comments
 (0)