Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/models/exaone4_5/test_modeling_exaone4_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers.testing_utils import (
Expectations,
cleanup,
require_deterministic_for_xpu,
require_torch,
slow,
torch_device,
Expand Down Expand Up @@ -164,6 +165,7 @@ def setUpClass(cls):
def tearDown(self):
cleanup(torch_device, gc_collect=True)

@require_deterministic_for_xpu
@slow
def test_model_logits(self):
input_ids = [70045, 1109, 115406, 16943, 11697, 115365, 19816, 12137, 375]
Expand All @@ -177,19 +179,26 @@ def test_model_logits(self):
("cuda", (8, 6)): torch.tensor(
[[44.8527, 45.7216, 71.1159, 36.9564, 44.3283, 22.0527, 28.3233, 62.5739, 46.0708]]
),
("xpu", None): torch.tensor(
[[45.2173, 45.4939, 71.0896, 37.1218, 44.3504, 22.1194, 28.6795, 62.5956, 45.9839]]
),
}
)
EXPECTED_SLICE = Expectations(
{
("cuda", (8, 6)): torch.tensor(
[42.2500, 43.0000, 42.5000, 44.7500, 49.5000, 46.0000, 46.5000, 46.5000, 45.7500, 46.2500]
),
("xpu", None): torch.tensor(
[42.7500, 43.5000, 42.7500, 45.2500, 50.0000, 46.5000, 46.7500, 46.7500, 46.0000, 46.5000]
),
}
)

torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN.get_expectation(), atol=1e-2, rtol=1e-2)
torch.testing.assert_close(out[0, 0, :10], EXPECTED_SLICE.get_expectation(), atol=1e-4, rtol=1e-4)

@require_deterministic_for_xpu
@slow
def test_model_generation_text_only(self):
EXPECTED_TEXT = Expectations(
Expand All @@ -198,6 +207,10 @@ def test_model_generation_text_only(self):
'\nTell me about the Miracle on the Han river.\n\n<think>\n\n</think>\n\nThe **"Miracle on the Han River"**'
" is a term used to describe the rapid economic development and industrialization that South Korea experienced"
),
("xpu", None): (
'\nTell me about the Miracle on the Han river.\n\n<think>\n\n</think>\n\nThe **"Miracle on the Han River"**'
" is a term used to describe the rapid economic development and industrialization that South Korea experienced"
),
}
)
messages = [
Expand All @@ -215,6 +228,7 @@ def test_model_generation_text_only(self):
text = self.processor.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(text, EXPECTED_TEXT.get_expectation())

@require_deterministic_for_xpu
@slow
def test_model_generation_image_text(self):
IMAGE_URL = (
Expand All @@ -225,6 +239,9 @@ def test_model_generation_image_text(self):
("cuda", 8): (
"\n\nDescribe the image.\n\n<think>\n\n</think>\n\nThe image captures a fluffy, young lynx kitten walking across a snowy surface, its thick"
),
("xpu", 3): (
"\n\nDescribe the image.\n\n<think>\n\n</think>\n\nThe image captures a young, fluffy wild cat—likely a lynx kitten—walking through a"
),
}
)
messages = [
Expand Down
Loading