Skip to content

Commit 029cd03

Browse files
wangkuiyichanglan
authored andcommitted
Make Bazel cover tensorflow-datasets and seqio
GitOrigin-RevId: 12dbda6
1 parent 0f98bf3 commit 029cd03

2 files changed

Lines changed: 122 additions & 94 deletions

File tree

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright © 2023 Apple Inc.
2+
3+
"""Tests for tf.data inputs that require GCS access."""
4+
5+
from typing import Optional
6+
7+
import jax
8+
import pytest
9+
import tensorflow_datasets as tfds
10+
from absl.testing import absltest, parameterized
11+
12+
from axlearn.common.config import config_for_function
13+
from axlearn.common.input_tf_data import (
14+
_infer_num_examples,
15+
_infer_num_shards,
16+
_maybe_shard_examples,
17+
tfds_dataset,
18+
tfds_read_config,
19+
)
20+
21+
22+
class TfdsGcsTest(parameterized.TestCase):
23+
"""Tests for TFDS functionality that require GCS access.
24+
25+
These tests require Google Cloud Storage authentication to access
26+
dataset metadata. They are separated from the main test file to
27+
allow Bazel to filter them at the target level.
28+
"""
29+
30+
@parameterized.parameters(
31+
("train", 1024), ("validation", 8), ("train[:512]", 1), ("invalid", None)
32+
)
33+
@pytest.mark.gs_login # For pytest compatibility
34+
def test_infer_num_shards(self, split: str, expected: Optional[int]):
35+
builder = tfds.builder("c4/en", try_gcs=True)
36+
self.assertEqual(_infer_num_shards(builder, split), expected)
37+
38+
@parameterized.parameters(
39+
("validation", 1043), ("test", 1063), ("test[:12]", 12), ("invalid", None)
40+
)
41+
@pytest.mark.gs_login # For pytest compatibility
42+
def test_infer_num_examples(self, split: str, expected: Optional[int]):
43+
builder = tfds.builder("glue/cola:2.0.0", try_gcs=True)
44+
self.assertEqual(_infer_num_examples(builder, split), expected)
45+
46+
@parameterized.parameters(
47+
("validation", 5, True, "even split"),
48+
("validation", 1044, False, "make copy for each host"),
49+
("validation", 1044, True, "raise value error"),
50+
("invalid", 5, True, "even split"),
51+
)
52+
@pytest.mark.gs_login
53+
def test_maybe_shard_examples(
54+
self, split: str, required_shards: int, is_training: bool, expected: str
55+
):
56+
dataset_name = "glue/cola:2.0.0"
57+
builder = tfds.builder(dataset_name, try_gcs=True)
58+
read_config = config_for_function(tfds_read_config).set(is_training=is_training)
59+
if expected == "raise value error":
60+
with self.assertRaises(ValueError):
61+
_ = _maybe_shard_examples(
62+
builder=builder,
63+
read_config=read_config,
64+
split=split,
65+
required_shards=required_shards,
66+
is_training=is_training,
67+
dataset_name=dataset_name,
68+
)
69+
else:
70+
per_process_split = _maybe_shard_examples(
71+
builder=builder,
72+
read_config=read_config,
73+
split=split,
74+
required_shards=required_shards,
75+
is_training=is_training,
76+
dataset_name=dataset_name,
77+
)
78+
if expected == "even split":
79+
shard_index = read_config.shard_index or jax.process_index()
80+
expected_split = tfds.even_splits(split, n=required_shards, drop_remainder=False)[
81+
shard_index
82+
]
83+
self.assertTrue(expected_split == per_process_split)
84+
elif expected == "make copy for each host":
85+
self.assertTrue(per_process_split == split)
86+
87+
@parameterized.parameters(
88+
("validation", True, "sentence", "foobar"),
89+
("test", True, "sentence", "barfoo"),
90+
("validation", False, "sentence", "bar bar"),
91+
("test", False, "sentence", "foo foo"),
92+
)
93+
@pytest.mark.gs_login
94+
def test_tfds_decoders(self, split: str, is_training: bool, field_name: str, expected: str):
95+
def tfds_custom_decoder() -> dict[str, tfds.decode.Decoder]:
96+
@tfds.decode.make_decoder()
97+
def replace_field_value(field_value, _):
98+
return field_value + expected
99+
100+
# pylint: disable=no-value-for-parameter
101+
return {field_name: replace_field_value()}
102+
103+
decoders = config_for_function(tfds_custom_decoder)
104+
105+
dataset_name = "glue/cola:2.0.0"
106+
source = config_for_function(tfds_dataset).set(
107+
dataset_name=dataset_name,
108+
split=split,
109+
is_training=is_training,
110+
train_shuffle_buffer_size=8,
111+
decoders=decoders,
112+
)
113+
ds = source.instantiate()
114+
115+
for input_batch in ds().take(5):
116+
assert expected in input_batch[field_name].numpy().decode(
117+
"UTF-8"
118+
), f"Missing {expected} string in {field_name} field"
119+
120+
121+
if __name__ == "__main__":
122+
absltest.main()

axlearn/common/input_tf_data_test.py

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import jax
1313
import numpy as np
14-
import pytest
1514
import seqio
1615
import tensorflow as tf
1716
import tensorflow_datasets as tfds
@@ -29,9 +28,6 @@
2928
DatasetToDatasetFn,
3029
Input,
3130
_infer_cardinality,
32-
_infer_num_examples,
33-
_infer_num_shards,
34-
_maybe_shard_examples,
3531
_pad_for_evaluation,
3632
_pad_logical_to_physical,
3733
add_static_fields,
@@ -279,96 +275,6 @@ def test_tfds_read_config_with_custom_sharding(self, num_shards, shard_index):
279275
self.assertEqual(read_config.input_context.num_input_pipelines, num_shards)
280276
self.assertEqual(read_config.input_context.input_pipeline_id, shard_index)
281277

282-
@parameterized.parameters(
283-
("train", 1024), ("validation", 8), ("train[:512]", 1), ("invalid", None)
284-
)
285-
@pytest.mark.gs_login # must annotate within @parameterized.parameters
286-
def test_infer_num_shards(self, split: str, expected: Optional[int]):
287-
builder = tfds.builder("c4/en", try_gcs=True)
288-
self.assertEqual(_infer_num_shards(builder, split), expected)
289-
290-
@parameterized.parameters(
291-
("validation", 1043), ("test", 1063), ("test[:12]", 12), ("invalid", None)
292-
)
293-
@pytest.mark.gs_login # must annotate within @parameterized.parameters
294-
def test_infer_num_examples(self, split: str, expected: Optional[int]):
295-
builder = tfds.builder("glue/cola:2.0.0", try_gcs=True)
296-
self.assertEqual(_infer_num_examples(builder, split), expected)
297-
298-
@parameterized.parameters(
299-
("validation", 5, True, "even split"),
300-
("validation", 1044, False, "make copy for each host"),
301-
("validation", 1044, True, "raise value error"),
302-
("invalid", 5, True, "even split"),
303-
)
304-
@pytest.mark.gs_login
305-
def test_maybe_shard_examples(
306-
self, split: str, required_shards: int, is_training: bool, expected: str
307-
):
308-
dataset_name = "glue/cola:2.0.0"
309-
builder = tfds.builder(dataset_name, try_gcs=True)
310-
read_config = config_for_function(tfds_read_config).set(is_training=is_training)
311-
if expected == "raise value error":
312-
with self.assertRaises(ValueError):
313-
_ = _maybe_shard_examples(
314-
builder=builder,
315-
read_config=read_config,
316-
split=split,
317-
required_shards=required_shards,
318-
is_training=is_training,
319-
dataset_name=dataset_name,
320-
)
321-
else:
322-
per_process_split = _maybe_shard_examples(
323-
builder=builder,
324-
read_config=read_config,
325-
split=split,
326-
required_shards=required_shards,
327-
is_training=is_training,
328-
dataset_name=dataset_name,
329-
)
330-
if expected == "even split":
331-
shard_index = read_config.shard_index or jax.process_index()
332-
expected_split = tfds.even_splits(split, n=required_shards, drop_remainder=False)[
333-
shard_index
334-
]
335-
self.assertTrue(expected_split == per_process_split)
336-
elif expected == "make copy for each host":
337-
self.assertTrue(per_process_split == split)
338-
339-
@parameterized.parameters(
340-
("validation", True, "sentence", "foobar"),
341-
("test", True, "sentence", "barfoo"),
342-
("validation", False, "sentence", "bar bar"),
343-
("test", False, "sentence", "foo foo"),
344-
)
345-
@pytest.mark.gs_login
346-
def test_tfds_decoders(self, split: str, is_training: bool, field_name: str, expected: str):
347-
def tfds_custom_decoder() -> dict[str, tfds.decode.Decoder]:
348-
@tfds.decode.make_decoder()
349-
def replace_field_value(field_value, _):
350-
return field_value + expected
351-
352-
# pylint: disable=no-value-for-parameter
353-
return {field_name: replace_field_value()}
354-
355-
decoders = config_for_function(tfds_custom_decoder)
356-
357-
dataset_name = "glue/cola:2.0.0"
358-
source = config_for_function(tfds_dataset).set(
359-
dataset_name=dataset_name,
360-
split=split,
361-
is_training=is_training,
362-
train_shuffle_buffer_size=8,
363-
decoders=decoders,
364-
)
365-
ds = source.instantiate()
366-
367-
for input_batch in ds().take(5):
368-
assert expected in input_batch[field_name].numpy().decode(
369-
"UTF-8"
370-
), f"Missing {expected} string in {field_name} field"
371-
372278
@parameterized.parameters(
373279
("inputs_pretokenized"),
374280
("prefix_ids"),

0 commit comments

Comments
 (0)