
Commit 24d000a

Reset conv_ops_impl.h to master - no changes needed for complex variable fix
The complex variable conj segfault fix only requires changes to:

- dense_update_functor_gpu.cu.cc (GPU kernel instantiation)
- resource_variable_ops.cc (GPU kernel registration)
- resource_variable_ops_test.py (test cases)

The conv_ops_impl.h file is unrelated to this fix and should not be modified.

1 parent: e463925 · commit: 24d000a

1 file changed: tensorflow/core/kernels/conv_ops_impl.h

Lines changed: 152 additions & 0 deletions

@@ -90,6 +90,41 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+// Maximum tensor size (in bytes) that cuDNN can handle safely.
+// cuDNN has internal limits around 2GB for certain operations.
+// We use a conservative threshold to avoid CUDA invalid resource handle errors.
+constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024;  // 2GB
+
+// Helper function to check if the tensor size exceeds the safe limit for cuDNN.
+// Returns true if the tensor is too large and needs fallback processing.
+template <typename T>
+inline bool IsTensorTooLargeForCudnn(const Tensor& tensor) {
+  int64_t tensor_size_bytes = tensor.NumElements() * sizeof(T);
+  return tensor_size_bytes > kMaxCudnnTensorSizeBytes;
+}
+
+// Helper function to compute the maximum batch size that keeps the tensor
+// under the cuDNN size limit.
+template <typename T>
+inline int64_t ComputeSafeBatchSize(const Tensor& tensor, int64_t current_batch,
+                                    TensorFormat data_format) {
+  if (current_batch <= 0) return 1;
+  int64_t total_elements = tensor.NumElements();
+  if (total_elements <= 0) return 1;
+  // Handle edge case where total_elements < current_batch.
+  if (total_elements < current_batch) {
+    // Each batch has less than 1 element on average, return 1.
+    return 1;
+  }
+  int64_t elements_per_batch = total_elements / current_batch;
+  if (elements_per_batch <= 0) return 1;
+  int64_t max_elements = kMaxCudnnTensorSizeBytes / sizeof(T);
+  int64_t safe_batch = max_elements / elements_per_batch;
+  // Ensure a batch size of at least 1, and cap at the current batch size.
+  return std::max(static_cast<int64_t>(1),
+                  std::min(safe_batch, current_batch));
+}
+
 template <typename Device, typename T>
 struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
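
To make the 2GB threshold and the chunk computation concrete, the following standalone sketch (not part of the commit) repeats the arithmetic of ComputeSafeBatchSize for a hypothetical float NHWC input of shape [64, 1024, 1024, 16]. The shape, variable names, and the reimplementation are illustrative assumptions, not code from conv_ops_impl.h.

// Standalone illustration (not part of the diff): mirrors the arithmetic in
// ComputeSafeBatchSize for a hypothetical float NHWC input of shape
// [64, 1024, 1024, 16] to show how the chunk size is derived.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024;  // 2GB
  constexpr int64_t kElemSize = static_cast<int64_t>(sizeof(float));      // 4 bytes

  const int64_t current_batch = 64;
  const int64_t elements_per_batch = 1024LL * 1024 * 16;  // H * W * C
  const int64_t total_elements = current_batch * elements_per_batch;

  // 64 * 16,777,216 elements * 4 bytes = 4 GiB, which exceeds the 2GB threshold.
  const int64_t total_bytes = total_elements * kElemSize;

  // max_elements = 2 GiB / 4 = 536,870,912; safe_batch = 536,870,912 / 16,777,216 = 32.
  const int64_t max_elements = kMaxCudnnTensorSizeBytes / kElemSize;
  const int64_t safe_batch = std::max<int64_t>(
      1, std::min(max_elements / elements_per_batch, current_batch));

  std::printf("total_bytes = %lld, safe_batch = %lld\n",
              static_cast<long long>(total_bytes),
              static_cast<long long>(safe_batch));
  return 0;
}

With these numbers the 4 GiB input exceeds the threshold, and the batch of 64 would be processed as two chunks of 32.
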
@@ -773,6 +808,123 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune,
               absl::InvalidArgumentError("filter must not have zero elements "
                                          "(i.e. all dimensions must be non-zero)"));
 
+  // Check if input tensor is too large for cuDNN and needs batch splitting.
+  // This addresses CUDA invalid resource handle errors with large tensors.
+  if (IsTensorTooLargeForCudnn<T>(input) && in_batch > 1) {
+    int64_t safe_batch = ComputeSafeBatchSize<T>(input, in_batch, data_format);
+    if (safe_batch < in_batch && safe_batch > 0) {
+      VLOG(2) << "Input tensor too large for cuDNN, splitting batch from "
+              << in_batch << " to chunks of " << safe_batch;
+
+      // Process in batches to avoid cuDNN memory limits
+      int64_t batch_idx = GetTensorDimIndex(data_format, 'N', input.dims());
+
+      // Validate batch dimension before proceeding
+      OP_REQUIRES(context, batch_idx >= 0 && batch_idx < input.dims(),
+                  absl::InternalError("Invalid batch dimension index"));
+      OP_REQUIRES(context, input.dim_size(batch_idx) > 0,
+                  absl::InternalError("Input batch dimension is zero"));
+      OP_REQUIRES(context, output->dim_size(batch_idx) > 0,
+                  absl::InternalError("Output batch dimension is zero"));
+
+      for (int64_t start = 0; start < in_batch; start += safe_batch) {
+        int64_t chunk_size = std::min(safe_batch, in_batch - start);
+
+        // Create sliced input tensor
+        std::vector<int64_t> input_slice_shape;
+        for (int i = 0; i < input.dims(); ++i) {
+          if (i == batch_idx) {
+            input_slice_shape.push_back(chunk_size);
+          } else {
+            input_slice_shape.push_back(input.dim_size(i));
+          }
+        }
+        TensorShape input_slice_ts(input_slice_shape);
+        Tensor input_slice;
+        OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                       input_slice_ts,
+                                                       &input_slice));
+
+        // Create sliced output tensor
+        std::vector<int64_t> output_slice_shape;
+        for (int i = 0; i < output->dims(); ++i) {
+          if (i == batch_idx) {
+            output_slice_shape.push_back(chunk_size);
+          } else {
+            output_slice_shape.push_back(output->dim_size(i));
+          }
+        }
+        TensorShape output_slice_ts(output_slice_shape);
+        Tensor output_slice;
+        OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                       output_slice_ts,
+                                                       &output_slice));
+
+        // Calculate elements per batch with validated dimensions
+        int64_t input_batch_dim = input.dim_size(batch_idx);
+        int64_t elements_per_batch = input.NumElements() / input_batch_dim;
+
+        // Validate bounds before pointer arithmetic
+        int64_t input_offset = start * elements_per_batch;
+        OP_REQUIRES(context, input_offset + chunk_size * elements_per_batch <=
+                                 input.NumElements(),
+                    absl::InternalError("Input slice bounds check failed"));
+
+        // Copy input slice from input tensor (device to device)
+        int64_t copy_size_bytes = chunk_size * elements_per_batch * sizeof(T);
+        auto src_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(input.template flat<T>().data() + input_offset),
+            copy_size_bytes);
+        auto dst_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(input_slice.template flat<T>().data()),
+            copy_size_bytes);
+        OP_REQUIRES_OK(context,
+                       stream->MemcpyD2D(&dst_ptr, src_ptr, copy_size_bytes));
+
+        // Recursively call LaunchConvOpImpl with the smaller batch.
+        // Safety note: The recursive call is guaranteed not to re-enter this
+        // batch-splitting code path because:
+        // 1. safe_batch is computed to keep sliced tensors under the size limit
+        // 2. IsTensorTooLargeForCudnn will return false for the sliced tensor
+        // 3. Even if it were to trigger, in_batch would equal chunk_size,
+        //    and safe_batch would equal chunk_size, so the condition
+        //    "safe_batch < in_batch" would be false
+        LaunchConvOpImpl<T>(context, cudnn_use_autotune, input_slice, filter,
+                            dilations, strides, padding, explicit_paddings,
+                            data_format, &output_slice);
+
+        // Check for errors from recursive call
+        if (!context->status().ok()) return;
+
+        // Calculate output elements per batch with validated dimensions
+        int64_t output_batch_dim = output->dim_size(batch_idx);
+        int64_t output_elements_per_batch =
+            output->NumElements() / output_batch_dim;
+
+        // Validate bounds before pointer arithmetic
+        int64_t output_offset = start * output_elements_per_batch;
+        OP_REQUIRES(
+            context,
+            output_offset + chunk_size * output_elements_per_batch <=
+                output->NumElements(),
+            absl::InternalError("Output slice bounds check failed"));
+
+        // Copy output slice to output tensor (device to device)
+        int64_t output_copy_size_bytes =
+            chunk_size * output_elements_per_batch * sizeof(T);
+        auto out_src_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(output_slice.template flat<T>().data()),
+            output_copy_size_bytes);
+        auto out_dst_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(output->template flat<T>().data() + output_offset),
+            output_copy_size_bytes);
+        OP_REQUIRES_OK(context, stream->MemcpyD2D(&out_dst_ptr, out_src_ptr,
+                                                  output_copy_size_bytes));
+      }
+      return;
+    }
+  }
+
   bool is_grouped_convolution = filter_depth != in_depth;
   // check if filter is 1x1 and stride/dilation are all ones
   bool one_filter = true;
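
For intuition about the splitting loop above, here is another standalone sketch (again not part of the commit) that enumerates the chunk sizes and flat element offsets the loop visits for a hypothetical in_batch of 70 and safe_batch of 32; the concrete numbers are assumptions chosen only to show how the last chunk is clamped.

// Standalone illustration of the chunking pattern used by the batch-splitting
// loop: fixed-size chunks, with the final chunk clamped by std::min so the
// chunks cover exactly in_batch samples.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t in_batch = 70;            // hypothetical batch size
  const int64_t safe_batch = 32;          // hypothetical chunk size
  const int64_t elements_per_batch = 16;  // stand-in for H * W * C

  for (int64_t start = 0; start < in_batch; start += safe_batch) {
    int64_t chunk_size = std::min(safe_batch, in_batch - start);
    int64_t input_offset = start * elements_per_batch;  // flat element offset
    std::printf("start=%lld chunk_size=%lld input_offset=%lld\n",
                static_cast<long long>(start),
                static_cast<long long>(chunk_size),
                static_cast<long long>(input_offset));
  }
  // Prints chunks of 32, 32, and 6 at offsets 0, 512, and 1024.
  return 0;
}

Because the chunk sizes sum to in_batch and each offset advances by start * elements_per_batch, the device-to-device copies in the real loop tile the input and output tensors without gaps or overlap.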
