Skip to content

Commit cf23723

Browse files
Fix segmentation fault with complex Variable operations and assign_add
Fixes tensorflow#105367 The issue was that complex types (complex64, complex128) were missing from: 1. GPU DenseUpdate functor template instantiations for ADD/SUB operations 2. GPU kernel registrations for AssignAddVariableOp and AssignSubVariableOp This caused a segmentation fault when using assign_add on complex Variables, particularly when combined with tf.raw_ops.Conj operations. Changes: - Added TF_CALL_COMPLEX_TYPES to dense_update_functor_gpu.cu.cc for ADD/SUB - Added TF_CALL_COMPLEX_TYPES to GPU kernel registrations in resource_variable_ops.cc - Added comprehensive test cases for complex variable assign_add operations
1 parent c4153df commit cf23723

4 files changed

Lines changed: 74 additions & 0 deletions

File tree

tensorflow/core/kernels/dense_update_functor_gpu.cu.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct DenseUpdate<GPUDevice, T, SUB> {
5757
template struct functor::DenseUpdate<GPUDevice, T, SUB>;
5858
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
5959
TF_CALL_INTEGRAL_TYPES(DEFINE_GPU_KERNELS);
60+
TF_CALL_COMPLEX_TYPES(DEFINE_GPU_KERNELS);
6061
TF_CALL_float8_e5m2(DEFINE_GPU_KERNELS);
6162
TF_CALL_float8_e4m3fn(DEFINE_GPU_KERNELS);
6263
#undef DEFINE_GPU_KERNELS

tensorflow/core/kernels/resource_variable_ops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
681681

682682
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
683683
TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS);
684+
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
684685
#undef REGISTER_GPU_KERNELS
685686
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
686687

tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1864,5 +1864,64 @@ def testGatherBatchDimsNeg(self):
18641864
)
18651865
self.evaluate(result)
18661866

1867+
@test_util.run_in_graph_and_eager_modes
1868+
@test_util.run_gpu_only
1869+
def testComplexVariableAssignAddWithConj(self):
1870+
"""Test for issue #105367: Segfault with complex Variable, Conj, and assign_add."""
1871+
# Test with complex64
1872+
input_data_64 = constant_op.constant([1 + 2j, 3 + 4j], dtype=dtypes.complex64)
1873+
var_64 = resource_variable_ops.ResourceVariable(input_data_64, dtype=dtypes.complex64)
1874+
self.evaluate(var_64.initializer)
1875+
1876+
conj_result_64 = math_ops.conj(input_data_64)
1877+
assign_add_op_64 = var_64.assign_add(conj_result_64)
1878+
result_64 = self.evaluate(assign_add_op_64)
1879+
1880+
# Expected: [1+2j, 3+4j] + [1-2j, 3-4j] = [2+0j, 6+0j]
1881+
expected_64 = np.array([2+0j, 6+0j], dtype=np.complex64)
1882+
self.assertAllClose(result_64, expected_64)
1883+
1884+
# Test with complex128
1885+
input_data_128 = constant_op.constant([1 + 2j, 3 + 4j], dtype=dtypes.complex128)
1886+
var_128 = resource_variable_ops.ResourceVariable(input_data_128, dtype=dtypes.complex128)
1887+
self.evaluate(var_128.initializer)
1888+
1889+
conj_result_128 = math_ops.conj(input_data_128)
1890+
assign_add_op_128 = var_128.assign_add(conj_result_128)
1891+
result_128 = self.evaluate(assign_add_op_128)
1892+
1893+
# Expected: [1+2j, 3+4j] + [1-2j, 3-4j] = [2+0j, 6+0j]
1894+
expected_128 = np.array([2+0j, 6+0j], dtype=np.complex128)
1895+
self.assertAllClose(result_128, expected_128)
1896+
1897+
@test_util.run_in_graph_and_eager_modes
1898+
def testComplexVariableAssignAddCPU(self):
1899+
"""Test complex Variable assign_add on CPU."""
1900+
# Test with complex64
1901+
input_data_64 = constant_op.constant([1 + 2j, 3 + 4j], dtype=dtypes.complex64)
1902+
var_64 = resource_variable_ops.ResourceVariable(input_data_64, dtype=dtypes.complex64)
1903+
self.evaluate(var_64.initializer)
1904+
1905+
delta_64 = constant_op.constant([0.5 - 1j, 1 + 0.5j], dtype=dtypes.complex64)
1906+
assign_add_op_64 = var_64.assign_add(delta_64)
1907+
result_64 = self.evaluate(assign_add_op_64)
1908+
1909+
# Expected: [1+2j, 3+4j] + [0.5-1j, 1+0.5j] = [1.5+1j, 4+4.5j]
1910+
expected_64 = np.array([1.5+1j, 4+4.5j], dtype=np.complex64)
1911+
self.assertAllClose(result_64, expected_64)
1912+
1913+
# Test with complex128
1914+
input_data_128 = constant_op.constant([1 + 2j, 3 + 4j], dtype=dtypes.complex128)
1915+
var_128 = resource_variable_ops.ResourceVariable(input_data_128, dtype=dtypes.complex128)
1916+
self.evaluate(var_128.initializer)
1917+
1918+
delta_128 = constant_op.constant([0.5 - 1j, 1 + 0.5j], dtype=dtypes.complex128)
1919+
assign_add_op_128 = var_128.assign_add(delta_128)
1920+
result_128 = self.evaluate(assign_add_op_128)
1921+
1922+
# Expected: [1+2j, 3+4j] + [0.5-1j, 1+0.5j] = [1.5+1j, 4+4.5j]
1923+
expected_128 = np.array([1.5+1j, 4+4.5j], dtype=np.complex128)
1924+
self.assertAllClose(result_128, expected_128)
1925+
18671926
if __name__ == "__main__":
18681927
test.main()

test_issue_105367.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import tensorflow as tf
2+
3+
print("TensorFlow version:", tf.__version__)
4+
print("Testing complex Variable operations with tf.raw_ops.Conj and assign_add...")
5+
6+
try:
7+
input_data = tf.constant([1 + 2j, 3 + 4j], dtype=tf.complex64)
8+
var = tf.Variable(input_data, dtype=tf.complex64)
9+
conj_result = tf.raw_ops.Conj(input=input_data)
10+
assign_add_op = var.assign_add(conj_result)
11+
print("Success! Result:", assign_add_op.numpy())
12+
except Exception as e:
13+
print(f"Error occurred: {type(e).__name__}: {e}")

0 commit comments

Comments
 (0)