Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
source
TensorFlow version
2.20.0
Custom code
Yes
OS platform and distribution
Fedora Linux 42 (Server Edition)
Mobile device
No response
Python version
3.12.4
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current behavior?
The logs told me to file a bug report. I hit this while working on jax2tf SavedModel export (see the standalone repro below).
tf_env.txt
== check python ====================================================
python version: 3.12.4
python branch:
python build version: ('main', 'Jul 4 2024 01:10:55')
python compiler version: GCC 14.1.1 20240701 (Red Hat 14.1.1-7)
python implementation: CPython
== check os platform ===============================================
os: Linux
os kernel version: #1 SMP PREEMPT_DYNAMIC Sun Oct 19 18:47:49 UTC 2025
os release version: 6.17.4-200.fc42.x86_64
os platform: Linux-6.17.4-200.fc42.x86_64-x86_64-with-glibc2.41
freedesktop os release: {'NAME': 'Fedora Linux', 'ID': 'fedora', 'PRETTY_NAME': 'Fedora Linux 42 (Server Edition)', 'VERSION': '42 (Server Edition)', 'RELEASE_TYPE': 'stable', 'VERSION_ID': '42', 'VERSION_CODENAME': '', 'PLATFORM_ID': 'platform:f42', 'ANSI_COLOR': '0;38;2;60;110;180', 'LOGO': 'fedora-logo-icon', 'CPE_NAME': 'cpe:/o:fedoraproject:fedora:42', 'HOME_URL': 'https://fedoraproject.org/', 'DOCUMENTATION_URL': 'https://docs.fedoraproject.org/en-US/fedora/f42/', 'SUPPORT_URL': 'https://ask.fedoraproject.org/', 'BUG_REPORT_URL': 'https://bugzilla.redhat.com/', 'REDHAT_BUGZILLA_PRODUCT': 'Fedora', 'REDHAT_BUGZILLA_PRODUCT_VERSION': '42', 'REDHAT_SUPPORT_PRODUCT': 'Fedora', 'REDHAT_SUPPORT_PRODUCT_VERSION': '42', 'SUPPORT_END': '2026-05-13', 'VARIANT': 'Server Edition', 'VARIANT_ID': 'server'}
mac version: ('', ('', '', ''), '')
uname: uname_result(system='Linux', node='distml', release='6.17.4-200.fc42.x86_64', version='#1 SMP PREEMPT_DYNAMIC Sun Oct 19 18:47:49 UTC 2025', machine='x86_64')
architecture: ('64bit', 'ELF')
machine: x86_64
== are we in docker ================================================
No
== c++ compiler ====================================================
/usr/bin/c++
c++ (GCC) 15.2.1 20250808 (Red Hat 15.2.1-1)
Copyright (C) 2025 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
== check pips ======================================================
numpy 2.3.1
protobuf 6.32.0
tensorflow 2.20.0
tf_nightly 2.21.0.dev20250829
== check for virtualenv ============================================
Running inside a virtual environment.
== tensorflow import ===============================================
2025-11-27 22:22:54.895643: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-27 22:22:54.924922: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-27 22:22:56.774100: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1764300176.775098 1685331 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3409 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:05:00.0, compute capability: 8.6
2025-11-27 22:22:56.775413: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
I0000 00:00:1764300176.776320 1685331 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 10138 MB memory: -> device: 1, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:01:00.0, compute capability: 8.6
tf.version.VERSION = 2.20.0
tf.version.GIT_VERSION = v2.20.0-rc0-4-g72fbba3d20f
tf.version.COMPILER_VERSION = Ubuntu Clang 18.1.8 (++20240731024944+3b5b5c1ec4a3-1~exp1~20240731145000.144)
Sanity check: <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>
libcudnn not found
== env =============================================================
LD_LIBRARY_PATH is unset
DYLD_LIBRARY_PATH is unset
== nvidia-smi ======================================================
Thu Nov 27 22:22:59 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 Off | N/A |
| 36% 32C P8 34W / 170W | 127MiB / 12288MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 3090 Off | 00000000:05:00.0 Off | N/A |
| 0% 37C P5 87W / 350W | 18915MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 78186 G /usr/bin/gnome-shell 2MiB |
| 0 N/A N/A 1657724 C ...12.4/envs/dash-env/bin/python 106MiB |
| 1 N/A N/A 78186 G /usr/bin/gnome-shell 322MiB |
| 1 N/A N/A 78459 G /usr/bin/Xwayland 6MiB |
| 1 N/A N/A 1403173 G ...per --variations-seed-version 100MiB |
| 1 N/A N/A 1657724 C ...12.4/envs/dash-env/bin/python 18368MiB |
+-----------------------------------------------------------------------------------------+
== cuda libs =======================================================
/usr/local/lib/ollama/cuda_v11/libcudart.so.11.3.109
/usr/local/lib/ollama/cuda_v12/libcudart.so.12.8.90
/usr/local/cuda-13.0/targets/x86_64-linux/lib/libcudart_static.a
/usr/local/cuda-13.0/targets/x86_64-linux/lib/libcudart.so.13.0.96
== tensorflow installation =========================================
Name: tensorflow
Version: 2.20.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /home/pdodeja/.pyenv/versions/3.12.4/envs/dash-env/lib/python3.12/site-packages
Required-by:
== tf_nightly installation =========================================
Name: tf_nightly
Version: 2.21.0.dev20250829
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /home/pdodeja/.pyenv/versions/3.12.4/envs/dash-env/lib/python3.12/site-packages
Required-by:
== python version ==================================================
(major, minor, micro, releaselevel, serial)
(3, 12, 4, 'final', 0)
== bazel version ===================================================
Build label: 6.1.0
Build time: Mon Mar 6 17:09:47 2023 (1678122587)
Build timestamp: 1678122587
Build timestamp as int: 1678122587
debug_output.txt
Standalone code to reproduce the issue
# For saving the model
import jax
import jax.numpy as jnp
import numpy as np
import optax # A popular optimization library for JAX
import tensorflow as tf
from jax.experimental import jax2tf
print("Using JAX version:", jax.__version__)
print("Using TensorFlow version:", tf.__version__)
print("Using Optax version:", optax.__version__)
# --- 1. Data Generation ---
# Let's create some simple data for y = 2x_1 - 3x_2 + 5
true_w = np.array([[2.0], [-3.0]])
true_b = 5.0
num_examples = 1000
# Generate random input data
key = jax.random.PRNGKey(42)
x_key, noise_key = jax.random.split(key)
X = jax.random.normal(x_key, (num_examples, 2))
# Generate corresponding labels with some noise
noise = jax.random.normal(noise_key, (num_examples, 1)) * 0.1
y = jnp.dot(X, true_w) + true_b + noise
print("\n--- Data Generation ---")
print(f"Input data shape (X): {X.shape}")
print(f"Label data shape (y): {y.shape}")
# --- 2. Create a tf.data Pipeline ---
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.shuffle(buffer_size=num_examples).batch(batch_size)
print("\n--- tf.data Pipeline ---")
print("Created a tf.data.Dataset object.")
# Inspect one batch
for x_batch, y_batch in dataset.take(1):
    print(f"One batch of X has shape: {x_batch.shape}")
    print(f"One batch of y has shape: {y_batch.shape}")
# --- 3. JAX Model, Loss, and Update Step ---
# Define the same simple linear model
def linear_model(params, x):
    return jnp.dot(x, params['w']) + params['b']
# Define the loss function (Mean Squared Error)
def mse_loss(params, x, y_true):
    y_pred = linear_model(params, x)
    return jnp.mean((y_pred - y_true)**2)
# Define the update step for training
# This function calculates loss, computes gradients, and applies updates
@jax.jit
def update_step(params, opt_state, x, y):
    # Calculate the loss and gradients
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # Update the parameters and optimizer state
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
# --- 4. The JAX Training Loop ---
# Initialize the model's parameters randomly
w_key, b_key = jax.random.split(key, 2)
params = {
    'w': jax.random.normal(w_key, (2, 1)),
    'b': jax.random.normal(b_key, (1,))
}
# Initialize the optimizer
learning_rate = 0.01
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)
print("\n--- JAX Training ---")
print("Initial random parameters:")
print(params)
# Training loop
epochs = 100
for epoch in range(epochs):
    epoch_loss = 0.0
    # *** Bridge from tf.data to JAX ***
    # We use .as_numpy_iterator() to feed the TF pipeline into the JAX loop
    for x_batch, y_batch in dataset.as_numpy_iterator():
        params, opt_state, loss = update_step(params, opt_state, x_batch, y_batch)
        epoch_loss += loss
    avg_loss = epoch_loss / len(dataset)
    if (epoch + 1) % 2 == 0:
        print(f"Epoch {epoch+1:2d}, Avg Loss: {avg_loss:.4f}")
print("\nTraining complete. Final trained parameters:")
print(params)
print(f"(Ground truth was w=[[2.], [-3.]], b=5.0)")
# --- 5. Convert Trained JAX Model to TensorFlow ---
# The function to convert is the simple forward pass
tf_predict_function = jax2tf.convert(
    linear_model,
    with_gradient=False,
    polymorphic_shapes=[None, '(b, 2)']  # params, input_x
)
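# Note: '(b, 2)' declares the batch dimension as symbolic, so the converted
# function can later be traced under a tf.function input_signature with shape
# [None, 2]; `None` for the params argument leaves those shapes fully static.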
# Create a tf.Module to wrap the function and trained parameters
class JaxTrainedModule(tf.Module):
    def __init__(self, params):
        super().__init__()
        # Convert JAX params to TensorFlow variables. This is important for saving.
        self.params = tf.nest.map_structure(tf.Variable, params)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 2], dtype=tf.float32)
    ])
    def __call__(self, x):
        # The converted function expects a dictionary of tensors
        return tf_predict_function(self.params, x)
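# (Wrapping the params in tf.Variable makes them trackable attributes of the
# tf.Module, so tf.saved_model.save serializes them as checkpointed variables
# rather than baking their values into the graph as constants.)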
# Instantiate the module with our final trained JAX parameters
tf_model = JaxTrainedModule(params)
# Test the converted model on a sample
test_input = np.array([[1.0, 1.0], [2.0, 3.0]], dtype=np.float32)
tf_output = tf_model(test_input)
print("\n--- Conversion and Verification ---")
print("Converted model prediction on test data:")
print(tf_output.numpy())
# Expected output for [[1,1], [2,3]] with true_w/b is [[4], [0]]
# Our model's output should be very close.
# --- 6. Save the model in SavedModel format ---
model_dir = './my_jax_model'
tf.saved_model.save(tf_model, model_dir)
print(f"\nModel successfully saved in SavedModel format at: {model_dir}")
print(f"You can inspect it with: !saved_model_cli show --dir {model_dir} --all")
Relevant log output
See the attached debug_output.txt above.