Skip to content

Commit ec747a8

Browse files
Correct imports of keras loss utils
PiperOrigin-RevId: 486795765
1 parent e334633 commit ec747a8

3 files changed

Lines changed: 5 additions & 8 deletions

File tree

tensorflow_privacy/privacy/estimators/binary_class_head.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""Binary class head for Estimator that allow integration with TF Privacy."""
1515

1616
import tensorflow as tf
17-
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
1817
from tensorflow_estimator.python.estimator import model_fn
1918
from tensorflow_estimator.python.estimator.canned import prediction_keys
2019
from tensorflow_estimator.python.estimator.export import export_output
@@ -55,7 +54,7 @@ def loss(self,
5554
labels = self._processed_labels(logits, labels)
5655
unweighted_loss, weights = self._unweighted_loss_and_weights(
5756
logits, labels, features)
58-
vector_training_loss = losses_utils.compute_weighted_loss(
57+
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
5958
unweighted_loss,
6059
sample_weight=weights,
6160
reduction=tf.keras.losses.Reduction.NONE)

tensorflow_privacy/privacy/estimators/multi_class_head.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""Multiclass head for Estimator that allow integration with TF Privacy."""
1515

1616
import tensorflow as tf
17-
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
1817
from tensorflow_estimator.python.estimator import model_fn
1918
from tensorflow_estimator.python.estimator.canned import prediction_keys
2019
from tensorflow_estimator.python.estimator.export import export_output
@@ -30,14 +29,14 @@ def __init__(self,
3029
n_classes,
3130
weight_column=None,
3231
label_vocabulary=None,
33-
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
32+
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
3433
loss_fn=None,
3534
name=None):
3635
super().__init__(
3736
n_classes=n_classes,
3837
weight_column=weight_column,
3938
label_vocabulary=label_vocabulary,
40-
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
39+
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
4140
loss_fn=loss_fn,
4241
name=name)
4342

@@ -55,7 +54,7 @@ def loss(self,
5554
labels = self._processed_labels(logits, labels)
5655
unweighted_loss, weights = self._unweighted_loss_and_weights(
5756
logits, labels, features)
58-
vector_training_loss = losses_utils.compute_weighted_loss(
57+
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
5958
unweighted_loss,
6059
sample_weight=weights,
6160
reduction=tf.keras.losses.Reduction.NONE)

tensorflow_privacy/privacy/estimators/multi_label_head.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""Multiclass head for Estimator that allow integration with TF Privacy."""
1515

1616
import tensorflow as tf
17-
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
1817
from tensorflow_estimator.python.estimator import model_fn
1918
from tensorflow_estimator.python.estimator.canned import prediction_keys
2019
from tensorflow_estimator.python.estimator.export import export_output
@@ -61,7 +60,7 @@ def loss(self,
6160
labels = self._processed_labels(logits, labels)
6261
unweighted_loss, weights = self._unweighted_loss_and_weights(
6362
logits, labels, features)
64-
vector_training_loss = losses_utils.compute_weighted_loss(
63+
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
6564
unweighted_loss,
6665
sample_weight=weights,
6766
reduction=tf.keras.losses.Reduction.NONE)

0 commit comments

Comments
 (0)