Skip to content

Commit ab52378

Browse files
Fix imports in mixed_dict_keys_test to use internal TensorFlow APIs
1 parent 30789ad commit ab52378

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

tensorflow/python/util/mixed_dict_keys_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from tensorflow.python.framework import constant_op
2+
from tensorflow.python.util import nest_util
13
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
24
#
35
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +20,6 @@
1820
fails when returning dictionaries with mixed key types (e.g., strings and integers).
1921
"""
2022

21-
import tensorflow as tf
2223
from tensorflow.python.platform import test
2324
from tensorflow.python.util import nest
2425

@@ -47,7 +48,7 @@ def simple_mixed_dict(x):
4748
results[123] = x + 1
4849
return results
4950

50-
input_tensor = tf.constant([1.0, 2.0, 3.0])
51+
input_tensor = constant_op.constant([1.0, 2.0, 3.0])
5152
output = simple_mixed_dict(input_tensor)
5253

5354
self.assertIn('string_key', output)
@@ -86,7 +87,7 @@ def multi_type_dict(x):
8687
results['str3'] = x + 5
8788
return results
8889

89-
input_tensor = tf.constant(10.0)
90+
input_tensor = constant_op.constant(10.0)
9091
output = multi_type_dict(input_tensor)
9192

9293
# Verify all keys are present
@@ -117,7 +118,7 @@ def nested_mixed_dict(x):
117118
}
118119
return outer
119120

120-
input_tensor = tf.constant(5.0)
121+
input_tensor = constant_op.constant(5.0)
121122
output = nested_mixed_dict(input_tensor)
122123

123124
self.assertIn('outer', output)
@@ -145,7 +146,7 @@ def no_xla_mixed_dict(x):
145146
results[123] = x + 1
146147
return results
147148

148-
input_tensor = tf.constant([1.0, 2.0])
149+
input_tensor = constant_op.constant([1.0, 2.0])
149150
output = no_xla_mixed_dict(input_tensor)
150151

151152
self.assertIn('string_key', output)
@@ -162,7 +163,7 @@ def consistent_dict(x):
162163
results[1] = x + 3
163164
return results
164165

165-
input_tensor = tf.constant(1.0)
166+
input_tensor = constant_op.constant(1.0)
166167

167168
# Call multiple times and verify same order
168169
output1 = consistent_dict(input_tensor)

0 commit comments

Comments
 (0)