-
Notifications
You must be signed in to change notification settings - Fork 122
Expand file tree
/
Copy pathtest_utils.py
More file actions
80 lines (62 loc) · 2.41 KB
/
Copy pathtest_utils.py
File metadata and controls
80 lines (62 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations
import json
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any
from unittest.mock import patch
import numpy as np
import pytest
import safetensors
import safetensors.numpy
from tokenizers import Tokenizer
from model2vec.distill.utils import select_optimal_device
from model2vec.hf_utils import _get_metadata_from_readme
from model2vec.utils import get_package_extras, importable
def test__get_metadata_from_readme_not_exists() -> None:
    """Test that a README path that is not expected to exist gives empty metadata."""
    missing_readme = Path("zzz")
    metadata = _get_metadata_from_readme(missing_readme)
    assert metadata == {}
def test__get_metadata_from_readme_mocked_file() -> None:
    """Test that front-matter metadata is read from an existing README file.

    The README is written into a temporary directory and is closed before it
    is parsed. Handing out the name of a still-open ``NamedTemporaryFile``
    is avoided because reopening such a file by name is not possible on every
    platform (notably Windows).
    """
    with TemporaryDirectory() as tmp_dir:
        readme_path = Path(tmp_dir) / "README.md"
        # write_bytes keeps the exact byte content (no newline translation).
        readme_path.write_bytes(b"---\nkey: value\n---\n")
        assert _get_metadata_from_readme(readme_path)["key"] == "value"
def test__get_metadata_from_readme_mocked_file_keys() -> None:
    """Test that an existing but empty README file yields no metadata keys.

    The empty README lives in a temporary directory and is closed before it
    is parsed. Handing out the name of a still-open ``NamedTemporaryFile``
    is avoided because reopening such a file by name is not possible on every
    platform (notably Windows).
    """
    with TemporaryDirectory() as tmp_dir:
        readme_path = Path(tmp_dir) / "README.md"
        # Create the file with no content at all.
        readme_path.write_bytes(b"")
        assert set(_get_metadata_from_readme(readme_path)) == set()
@pytest.mark.parametrize(
    "device, expected, cuda, mps",
    [
        # An explicitly requested device is expected back unchanged,
        # whatever the patched availability flags say.
        ("cpu", "cpu", True, True),
        ("cpu", "cpu", True, False),
        ("cpu", "cpu", False, True),
        ("cpu", "cpu", False, False),
        ("clown", "clown", False, False),
        # With no device requested: cuda is expected before mps, cpu last.
        (None, "cuda", True, True),
        (None, "cuda", True, False),
        (None, "mps", False, True),
        (None, "cpu", False, False),
    ],
)
def test_select_optimal_device(device: str | None, expected: str, cuda: bool, mps: bool) -> None:
    """Test device selection for each requested device and patched CUDA/MPS availability."""
    cuda_patch = patch("torch.cuda.is_available", return_value=cuda)
    mps_patch = patch("torch.backends.mps.is_available", return_value=mps)

    # Only the call under test needs the patched availability checks.
    with cuda_patch, mps_patch:
        selected = select_optimal_device(device)

    assert selected == expected
def test_importable() -> None:
    """Test that importable raises only for a module that cannot be imported."""
    # A standard-library module is expected to pass the check without raising.
    importable("os", "clown")

    # A module name that is not expected to be installed must raise ImportError.
    with pytest.raises(ImportError):
        importable("clown", "clown")
def test_get_package_extras() -> None:
    """Test the packages reported for the ``distill`` extra of ``model2vec``."""
    expected_extras = {"torch", "transformers", "scikit-learn"}
    distill_extras = get_package_extras("model2vec", "distill")
    # Compare as sets: the order in which extras are yielded is not checked.
    assert set(distill_extras) == expected_extras
def test_get_package_extras_empty() -> None:
    """Test that an empty extra name for ``tqdm`` yields no packages."""
    found_extras = list(get_package_extras("tqdm", ""))
    assert not found_extras