|
| 1 | +# Copyright © 2023 Apple Inc. |
| 2 | + |
| 3 | +"""Tests for tf.data inputs that require GCS access.""" |
| 4 | + |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +import jax |
| 8 | +import pytest |
| 9 | +import tensorflow_datasets as tfds |
| 10 | +from absl.testing import absltest, parameterized |
| 11 | + |
| 12 | +from axlearn.common.config import config_for_function |
| 13 | +from axlearn.common.input_tf_data import ( |
| 14 | + _infer_num_examples, |
| 15 | + _infer_num_shards, |
| 16 | + _maybe_shard_examples, |
| 17 | + tfds_dataset, |
| 18 | + tfds_read_config, |
| 19 | +) |
| 20 | + |
| 21 | + |
class TfdsGcsTest(parameterized.TestCase):
    """Tests for TFDS functionality that require GCS access.

    These tests require Google Cloud Storage authentication to access
    dataset metadata. They are separated from the main test file to
    allow Bazel to filter them at the target level.
    """

    @parameterized.parameters(
        ("train", 1024), ("validation", 8), ("train[:512]", 1), ("invalid", None)
    )
    @pytest.mark.gs_login  # For pytest compatibility
    def test_infer_num_shards(self, split: str, expected: Optional[int]):
        """Checks `_infer_num_shards` for several splits of the "c4/en" builder.

        Args:
            split: Split spec handed to `_infer_num_shards`.
            expected: Value the helper is expected to return (None for the "invalid" split).
        """
        builder = tfds.builder("c4/en", try_gcs=True)
        self.assertEqual(_infer_num_shards(builder, split), expected)

    @parameterized.parameters(
        ("validation", 1043), ("test", 1063), ("test[:12]", 12), ("invalid", None)
    )
    @pytest.mark.gs_login  # For pytest compatibility
    def test_infer_num_examples(self, split: str, expected: Optional[int]):
        """Checks `_infer_num_examples` for several splits of "glue/cola:2.0.0".

        Args:
            split: Split spec handed to `_infer_num_examples`.
            expected: Value the helper is expected to return (None for the "invalid" split).
        """
        builder = tfds.builder("glue/cola:2.0.0", try_gcs=True)
        self.assertEqual(_infer_num_examples(builder, split), expected)

    @parameterized.parameters(
        ("validation", 5, True, "even split"),
        ("validation", 1044, False, "make copy for each host"),
        ("validation", 1044, True, "raise value error"),
        ("invalid", 5, True, "even split"),
    )
    @pytest.mark.gs_login  # For pytest compatibility
    def test_maybe_shard_examples(
        self, split: str, required_shards: int, is_training: bool, expected: str
    ):
        """Checks the outcome of `_maybe_shard_examples` for each parameter set.

        Args:
            split: Split spec handed to `_maybe_shard_examples`.
            required_shards: Number of shards requested.
            is_training: Forwarded to both the read config and `_maybe_shard_examples`.
            expected: Label selecting the check to run: "even split",
                "make copy for each host" or "raise value error".
        """
        dataset_name = "glue/cola:2.0.0"
        builder = tfds.builder(dataset_name, try_gcs=True)
        read_config = config_for_function(tfds_read_config).set(is_training=is_training)
        # Built once so the raising and non-raising paths exercise the exact same call.
        shard_kwargs = dict(
            builder=builder,
            read_config=read_config,
            split=split,
            required_shards=required_shards,
            is_training=is_training,
            dataset_name=dataset_name,
        )

        if expected == "raise value error":
            with self.assertRaises(ValueError):
                _maybe_shard_examples(**shard_kwargs)
            return

        per_process_split = _maybe_shard_examples(**shard_kwargs)
        if expected == "even split":
            # Only `is_training` is set on the config in this test, so a falsy
            # `shard_index` falls back to this process's index.
            shard_index = read_config.shard_index or jax.process_index()
            even_splits = tfds.even_splits(split, n=required_shards, drop_remainder=False)
            self.assertEqual(even_splits[shard_index], per_process_split)
        elif expected == "make copy for each host":
            self.assertEqual(per_process_split, split)
        else:
            # Guards against a mistyped label silently asserting nothing.
            self.fail(f"Unknown expectation: {expected!r}")

    @parameterized.parameters(
        ("validation", True, "sentence", "foobar"),
        ("test", True, "sentence", "barfoo"),
        ("validation", False, "sentence", "bar bar"),
        ("test", False, "sentence", "foo foo"),
    )
    @pytest.mark.gs_login  # For pytest compatibility
    def test_tfds_decoders(self, split: str, is_training: bool, field_name: str, expected: str):
        """Checks that a custom decoder passed to `tfds_dataset` is applied to examples.

        Args:
            split: Split of "glue/cola:2.0.0" to read.
            is_training: Forwarded to the `tfds_dataset` config.
            field_name: Example field the custom decoder is attached to.
            expected: Suffix the decoder appends; must appear in every inspected example.
        """

        def tfds_custom_decoder() -> dict[str, tfds.decode.Decoder]:
            @tfds.decode.make_decoder()
            def replace_field_value(field_value, _):
                # Appending a known marker lets the test detect that the decoder ran.
                return field_value + expected

            # pylint: disable=no-value-for-parameter
            return {field_name: replace_field_value()}

        decoders = config_for_function(tfds_custom_decoder)

        dataset_name = "glue/cola:2.0.0"
        source = config_for_function(tfds_dataset).set(
            dataset_name=dataset_name,
            split=split,
            is_training=is_training,
            train_shuffle_buffer_size=8,
            decoders=decoders,
        )
        ds = source.instantiate()

        num_checked = 0
        for input_batch in ds().take(5):
            decoded = input_batch[field_name].numpy().decode("UTF-8")
            # assertIn (rather than a bare assert) survives `python -O` and reports both operands.
            self.assertIn(expected, decoded, f"Missing {expected} string in {field_name} field")
            num_checked += 1
        # An empty dataset would otherwise make the loop above pass vacuously.
        self.assertGreater(num_checked, 0)
| 119 | + |
| 120 | + |
# Allow running this file directly as a script; absltest discovers and runs the tests.
if __name__ == "__main__":
    absltest.main()
0 commit comments