
TFRS and TF 2.16 compatibility #757

@Daard

Description


System information

OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
TensorFlow version and how it was installed (source or binary): tensorflow 2.16
TensorFlow-Recommenders-Addons version and how it was installed (source or binary): tfrs v0.7.3
Python version: 3.10
Is GPU used? (yes/no): yes
Describe the bug

I am using TFRS with TF 2.16. The following exception is raised when model.fit(...) is called:

Exception encountered when calling Retrieval.call().

Can not convert a NoneType into a Tensor or Operation.
I have managed to localise the problem: the exception is raised only after I add a batch metric to the Retrieval task.
All my code worked fine on TF 2.14, but on TF 2.16 I have to remove the batch metric from the task to avoid the error.

import tensorflow as tf
import tensorflow_recommenders as tfrs


class LogitsAccuracy(tf.keras.metrics.Accuracy):
    """Custom accuracy metric for a diagonal y_true and a matrix of query-candidate scores."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Predicted candidate for each query: the column with the highest score.
        new_preds = tf.argmax(y_pred, axis=-1)
        # new_trues = tf.range(tf.linalg.trace(y_true))
        # With in-batch candidates, the positive for row i sits in column i.
        batch_size = tf.shape(y_pred)[0]
        new_trues = tf.range(batch_size)

        # Explicit casts to ensure correct dtypes
        new_preds = tf.cast(new_preds, tf.int32)
        new_trues = tf.cast(new_trues, tf.int32)

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, tf.float32)

        return super().update_state(new_trues, new_preds, sample_weight=sample_weight)


# Inside the model's __init__. Uncommenting batch_metrics triggers the exception on TF 2.16.
self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
    # batch_metrics=[LogitsAccuracy(name='accuracy')]
)
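A minimal sketch of what appears to be happening, reproduced without TFRS. It rests on two assumptions not confirmed in this report: that Retrieval.call() collects the return value of each batch metric's update_state() and returns the loss inside tf.control_dependencies(...) on that list, and that under Keras 3 (the default tf.keras in TF 2.16) update_state() returns None. Inside a tf.function, as used by model.fit(...), a None control input raises the same message. The function names below are hypothetical.

```python
import tensorflow as tf


@tf.function(autograph=False)
def loss_with_batch_metric_deps(loss):
    # Hypothetical stand-in for the list of batch-metric update results.
    # Assumption: a Keras 3 metric's update_state() returns None.
    update_ops = [None]
    # Assumed TFRS pattern: return the loss under control dependencies
    # on the metric updates. In graph mode, None cannot be converted.
    with tf.control_dependencies(update_ops):
        return tf.identity(loss)


def reproduce():
    """Return the error message raised while tracing, or None if no error."""
    try:
        loss_with_batch_metric_deps(tf.constant(1.0))
    except (TypeError, ValueError) as err:
        return str(err)
    return None


print(reproduce())
```

In eager mode tf.control_dependencies does not convert its inputs, so this pattern only fails once it is traced into a graph.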

How can I fix this problem? I have also tested standard Keras metrics as batch metrics, and the result is the same: the exception is still raised.
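One possible workaround, assuming the error comes from Keras 3 metrics returning None from update_state(): the TF 2.16 release notes describe switching tf.keras back to Keras 2 by installing the tf-keras package (pip install tf-keras~=2.16) and setting TF_USE_LEGACY_KERAS=1 before TensorFlow is imported. Keras 2 metrics return an update op from update_state(), as they did on TF 2.14. This is a sketch, not a confirmed fix for TFRS v0.7.3.

```python
import os

# Assumes `pip install tf-keras~=2.16` has been run. The variable must be set
# before TensorFlow (or tensorflow_recommenders) is imported anywhere in the
# process, and it affects every package in the process that uses tf.keras.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf  # noqa: E402

# With the flag honoured, tf.keras resolves to Keras 2 (the tf_keras package).
print(tf.keras.__version__)
```

Setting the variable in the shell (export TF_USE_LEGACY_KERAS=1) has the same effect and avoids depending on import order.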
