|
29 | 29 | is_torch_available, |
30 | 30 | ) |
31 | 31 | from transformers.testing_utils import ( |
| 32 | + Expectations, |
32 | 33 | cleanup, |
| 34 | + require_deterministic_for_xpu, |
33 | 35 | require_torch, |
34 | 36 | slow, |
35 | 37 | torch_device, |
@@ -345,16 +347,38 @@ def test_fixture_single_matches(self): |
345 | 347 | txt = self.processor.batch_decode(gen_ids, skip_special_tokens=True) |
346 | 348 | self.assertListEqual(txt, exp_txt) |
347 | 349 |
|
| 350 | + @require_deterministic_for_xpu |
348 | 351 | @slow |
349 | 352 | def test_fixture_batched_matches(self): |
350 | 353 | """ |
351 | 354 | reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/a3226a0ba25e51be84a4808a79b59257#file-reproducer_hf-py |
352 | 355 | """ |
353 | | - path = Path(__file__).parent.parent.parent / "fixtures/musicflamingo/expected_results_batched.json" |
354 | | - with open(path, "r", encoding="utf-8") as f: |
355 | | - raw = json.load(f) |
356 | | - exp_ids = torch.tensor(raw["token_ids"]) |
357 | | - exp_txt = raw["transcriptions"] |
| 356 | + # fmt: off |
| 357 | + exp_ids = Expectations( |
| 358 | + { |
| 359 | + ("cuda", None): torch.tensor([ |
| 360 | + [1986, 3754, 374, 458, 94509, 19461, 98875, 55964, 3528, 1163, 681, 55964, 11598, 55564, 429, 57843, 279, 9842, 3040, 55964, 263, 55964, 1782, 55964, 30449, 27235, 315, 11416, 19461, 98875, 448, 279, 68897, 11, 10581, 52760, 42898, 975, 14260, 315, 6481, 97431, 55964, 13573, 2591, 2420, 13, 220, 576, 8090], |
| 361 | + [334, 68043, 220, 16, 1019, 33648, 9287, 88828, 304, 51454, 11, 12711, 28347, 261, 304, 279, 3054, 11, 24353, 20783, 18707, 30789, 11, 22502, 4614, 389, 279, 49293, 271, 334, 68043, 220, 17, 1019, 26843, 2367, 98091, 389, 279, 39612, 11, 304, 17172, 582, 6950, 11, 14697, 41315, 311, 279], |
| 362 | + ]), |
| 363 | + ("xpu", None): torch.tensor([ |
| 364 | + [1986, 3754, 374, 458, 94509, 19461, 98875, 55964, 3528, 1163, 681, 55964, 11598, 55564, 429, 57843, 279, 9842, 3040, 55964, 263, 55964, 1782, 55964, 30449, 27235, 315, 11416, 19461, 98875, 448, 279, 68897, 11, 10581, 52760, 42898, 975, 14260, 315, 6481, 97431, 55964, 13573, 2591, 2420, 13, 220, 576, 8090], |
| 365 | + [334, 68043, 220, 16, 1019, 33648, 9287, 88828, 304, 51454, 11, 12711, 28347, 261, 304, 279, 3054, 11, 24353, 20783, 18707, 30789, 11, 22502, 4614, 389, 2518, 49293, 271, 334, 68043, 220, 17, 1019, 26843, 2367, 98091, 389, 279, 39612, 11, 304, 17172, 582, 6950, 11, 14697, 41315, 311, 279], |
| 366 | + ]), |
| 367 | + } |
| 368 | + ).get_expectation() |
| 369 | + exp_txt = Expectations( |
| 370 | + { |
| 371 | + ("cuda", None): [ |
| 372 | + "This track is an uplifting Eurodance‑style Trance‑Pop anthem that blends the driving four‑on‑the‑floor pulse of classic Eurodance with the soaring, melodic synth work typical of modern trance‑infused pop. The duration", |
| 373 | + "**Verse 1**\nMidnight cravings in bloom, lights flicker in the room, pepperoni dreams arise, pizza party on the skies\n\n**Verse 2**\nCheese melts on the crust, in flavor we trust, boxes stacked to the", |
| 374 | + ], |
| 375 | + ("xpu", None): [ |
| 376 | + "This track is an uplifting Eurodance‑style Trance‑Pop anthem that blends the driving four‑on‑the‑floor pulse of classic Eurodance with the soaring, melodic synth work typical of modern trance‑infused pop. The duration", |
| 377 | + "**Verse 1**\nMidnight cravings in bloom, lights flicker in the room, pepperoni dreams arise, pizza party on red skies\n\n**Verse 2**\nCheese melts on the crust, in flavor we trust, boxes stacked to the", |
| 378 | + ], |
| 379 | + } |
| 380 | + ).get_expectation() |
| 381 | + # fmt: on |
358 | 382 |
|
359 | 383 | conversations = [ |
360 | 384 | [ |
|
0 commit comments