Skip to content

Commit dd33a98

Browse files
Fix XLA mixed dict-key sorting: update nest_util and add mixed_dict_keys_test
1 parent 15fc9e6 commit dd33a98

2 files changed

Lines changed: 198 additions & 6 deletions

File tree

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import tensorflow as tf

from tensorflow.python.framework import constant_op
from tensorflow.python.util import nest_util
3+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# ==============================================================================
17+
"""Tests for XLA JIT compilation with mixed-type dictionary keys.
18+
19+
This test validates the fix for issue #105333 where @tf.function(jit_compile=True)
20+
fails when returning dictionaries with mixed key types (e.g., strings and integers).
21+
"""
22+
23+
from tensorflow.python.platform import test
24+
from tensorflow.python.util import nest
25+
26+
27+
class XLAMixedDictKeysTest(test.TestCase):
  """Tests nest utilities and `tf.function` outputs with mixed-type dict keys.

  Every case builds a dict whose keys mix `str` and `int`, then checks that the
  dict can be flattened, repacked, or returned from a `tf.function` (with and
  without `jit_compile=True`) while each key keeps its own value.
  """

  def test_mixed_string_int_keys_flatten(self):
    """Flattening orders values by key type name, then by key value."""
    mixed_dict = {'string_key': 1, 123: 2, 'another': 3, 456: 4}
    flattened = nest.flatten(mixed_dict)
    # Keys are expected in the order [123, 456, 'another', 'string_key']
    # ('int' sorts before 'str'); the flat values follow that key order.
    self.assertEqual(flattened, [2, 4, 3, 1])

  def test_mixed_keys_with_xla_simple(self):
    """A JIT-compiled function can return a dict with str and int keys."""

    @tf.function(jit_compile=True)
    def simple_mixed_dict(x):
      results = {}
      results['string_key'] = x
      results[123] = x + 1
      return results

    input_tensor = constant_op.constant([1.0, 2.0, 3.0])
    output = simple_mixed_dict(input_tensor)

    self.assertIn('string_key', output)
    self.assertIn(123, output)
    self.assertAllClose(output['string_key'], [1.0, 2.0, 3.0])
    self.assertAllClose(output[123], [2.0, 3.0, 4.0])

  def test_mixed_keys_with_xla_in_model(self):
    """Test XLA with mixed dict keys in Keras model (original issue #105333)."""

    class SimpleModel(tf.keras.Model):

      @tf.function(jit_compile=True)
      def call(self, x):
        results = {}
        results['string_key'] = x
        results[123] = x + 1
        return x, results

    model = SimpleModel()
    input_tensor = tf.random.normal([2, 16, 16, 16, 32])
    output_tensor, output_dict = model(input_tensor)

    self.assertEqual(output_tensor.shape, (2, 16, 16, 16, 32))
    self.assertIn('string_key', output_dict)
    self.assertIn(123, output_dict)
    # Presence alone would not notice the two entries being swapped.
    self.assertAllClose(output_dict['string_key'], input_tensor)
    self.assertAllClose(output_dict[123], input_tensor + 1)

  def test_multiple_mixed_types(self):
    """Every key of a larger mixed-key dict maps to its own value."""

    @tf.function(jit_compile=True)
    def multi_type_dict(x):
      results = {}
      results['str1'] = x
      results[1] = x + 1
      results['str2'] = x + 2
      results[2] = x + 3
      results[3] = x + 4
      results['str3'] = x + 5
      return results

    input_tensor = constant_op.constant(10.0)
    output = multi_type_dict(input_tensor)

    expected = {
        'str1': 10.0,
        1: 11.0,
        'str2': 12.0,
        2: 13.0,
        3: 14.0,
        'str3': 15.0,
    }
    # Exactly the expected keys are present, no more and no fewer.
    self.assertCountEqual(output.keys(), expected.keys())
    for key, value in expected.items():
      self.assertAlmostEqual(output[key].numpy(), value)

  def test_nested_mixed_keys(self):
    """Mixed keys are supported at every nesting level."""

    @tf.function(jit_compile=True)
    def nested_mixed_dict(x):
      inner = {
          'inner_str': x,
          100: x + 1,
      }
      outer = {
          'outer': inner,
          200: x + 2,
      }
      return outer

    input_tensor = constant_op.constant(5.0)
    output = nested_mixed_dict(input_tensor)

    self.assertIn('outer', output)
    self.assertIn(200, output)
    self.assertIn('inner_str', output['outer'])
    self.assertIn(100, output['outer'])
    self.assertAllClose(output['outer']['inner_str'], 5.0)
    self.assertAllClose(output['outer'][100], 6.0)
    self.assertAllClose(output[200], 7.0)

  def test_pack_sequence_as_with_mixed_keys(self):
    """pack_sequence_as assigns flat values in sorted key order."""
    structure = {'a': 1, 10: 2, 'b': 3, 20: 4}
    flat_sequence = [100, 200, 300, 400]

    packed = nest.pack_sequence_as(structure, flat_sequence)

    # Values are assigned in sorted key order (int keys first, then str keys),
    # i.e. to the keys [10, 20, 'a', 'b'] in turn.
    self.assertEqual(packed, {10: 100, 20: 200, 'a': 300, 'b': 400})

  def test_without_xla_still_works(self):
    """Mixed keys also work when XLA compilation is disabled."""

    @tf.function(jit_compile=False)
    def no_xla_mixed_dict(x):
      results = {}
      results['string_key'] = x
      results[123] = x + 1
      return results

    input_tensor = constant_op.constant([1.0, 2.0])
    output = no_xla_mixed_dict(input_tensor)

    self.assertIn('string_key', output)
    self.assertIn(123, output)
    self.assertAllClose(output['string_key'], [1.0, 2.0])
    self.assertAllClose(output[123], [2.0, 3.0])

  def test_consistent_ordering(self):
    """Repeated calls return the same keys, order, and key/value pairing."""

    @tf.function(jit_compile=True)
    def consistent_dict(x):
      results = {}
      results['z'] = x
      results[3] = x + 1
      results['a'] = x + 2
      results[1] = x + 3
      return results

    input_tensor = constant_op.constant(1.0)
    outputs = [consistent_dict(input_tensor) for _ in range(3)]

    expected = {'z': 1.0, 3: 2.0, 'a': 3.0, 1: 4.0}
    first_keys = list(outputs[0])
    self.assertCountEqual(first_keys, expected.keys())
    for output in outputs:
      # Compare the raw key order: sorting the keys before comparing would
      # make the check pass regardless of the order actually returned.
      self.assertEqual(list(output), first_keys)
      for key, value in expected.items():
        self.assertAlmostEqual(output[key].numpy(), value)
179+
180+
181+
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
  test.main()

tensorflow/python/util/nest_util.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,19 +272,29 @@ def _tf_core_sorted(dict_):
272272
try:
273273
return sorted(dict_.keys())
274274
except TypeError:
275-
# pylint: disable=raise-missing-from
276-
raise TypeError("nest only supports dicts with sortable keys.")
275+
# If direct sorting fails (e.g., mixed types like int and str),
276+
# try sorting by (type name, key) to group by type first, then by value
277+
try:
278+
return sorted(dict_.keys(), key=lambda x: (type(x).__name__, x))
279+
except TypeError:
280+
# If that still fails, fall back to sorting by string representation
281+
# This ensures deterministic ordering even with complex mixed types
282+
return sorted(dict_.keys(), key=lambda x: (type(x).__name__, str(x)))
277283

278284

279285
def _tf_data_sorted(dict_):
  """Returns the keys of `dict_` as a deterministically ordered list.

  Three orderings are attempted in turn and the first that succeeds is used:

  1. The natural ordering of the keys.
  2. Ordering by `(type name, key)`, which groups keys by type so that keys of
     different types (e.g. `int` and `str`) are never compared to each other.
  3. Ordering by `(type name, str(key))`, for keys that share a type name but
     still cannot be compared with each other.

  Args:
    dict_: The dict whose keys are sorted.

  Returns:
    A new list holding the keys of `dict_` in sorted order.
  """
  keys = list(dict_)
  # `None` selects the natural ordering; the second entry groups by type name.
  preferred_key_fns = (
      None,
      lambda k: (type(k).__name__, k),
  )
  for key_fn in preferred_key_fns:
    try:
      return sorted(keys, key=key_fn)
    except TypeError:
      # The keys are not mutually comparable under this ordering; try the
      # next, more permissive one.
      continue
  # Last resort: compare string representations within each type name.
  return sorted(keys, key=lambda k: (type(k).__name__, str(k)))
288298

289299

290300
def yield_value(modality, iterable):

0 commit comments

Comments
 (0)