Skip to content

Commit 27d94a0

Browse files
committed
style: apply ruff formatting to test files
1 parent af11c60 commit 27d94a0

2 files changed

Lines changed: 15 additions & 15 deletions

File tree

tests/test_base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def test_model_has_forward_method(model_class):
5454
assert hasattr(model_class, "forward"), f"{model_class.__name__} is missing a forward method."
5555

5656
sig = inspect.signature(model_class.forward)
57-
assert any(
58-
p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()
59-
), f"{model_class.__name__}.forward should have *data argument."
57+
assert any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()), (
58+
f"{model_class.__name__}.forward should have *data argument."
59+
)
6060

6161

6262
@pytest.mark.parametrize("model_class", model_classes)
@@ -77,9 +77,9 @@ def test_model_has_num_classes(model_class):
7777
def test_model_calls_super_init(model_class):
7878
"""Test that each model calls super().__init__(config=config, **kwargs)."""
7979
source = inspect.getsource(model_class.__init__)
80-
assert (
81-
"super().__init__(config=config" in source
82-
), f"{model_class.__name__} should call super().__init__(config=config, **kwargs)."
80+
assert "super().__init__(config=config" in source, (
81+
f"{model_class.__name__} should call super().__init__(config=config, **kwargs)."
82+
)
8383

8484

8585
@pytest.mark.parametrize("model_class", model_classes)

tests/test_configs.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,20 @@ def test_config_default_values(config_class):
6666
if origin is typing.Literal:
6767
# If the field is a Literal, ensure the value is one of the allowed options
6868
allowed_values = typing.get_args(expected_type)
69-
assert (
70-
value in allowed_values
71-
), f"{config_class.__name__}.{attr} has incorrect value: expected one of {allowed_values}, got {value}"
69+
assert value in allowed_values, (
70+
f"{config_class.__name__}.{attr} has incorrect value: expected one of {allowed_values}, got {value}"
71+
)
7272
elif origin is typing.Union:
7373
# For Union types (e.g., Optional[str]), check if value matches any type in the union
7474
allowed_types = typing.get_args(expected_type)
75-
assert any(
76-
isinstance(value, t) for t in allowed_types
77-
), f"{config_class.__name__}.{attr} has incorrect type: expected one of {allowed_types}, got {type(value)}"
75+
assert any(isinstance(value, t) for t in allowed_types), (
76+
f"{config_class.__name__}.{attr} has incorrect type: expected one of {allowed_types}, got {type(value)}"
77+
)
7878
elif origin is not None:
7979
# If it's another generic type (e.g., list[str]), check against the base type
80-
assert (
81-
isinstance(value, origin) or value is None
82-
), f"{config_class.__name__}.{attr} has incorrect type: expected {expected_type}, got {type(value)}"
80+
assert isinstance(value, origin) or value is None, (
81+
f"{config_class.__name__}.{attr} has incorrect type: expected {expected_type}, got {type(value)}"
82+
)
8383
else:
8484
# Standard type check
8585
assert (

0 commit comments

Comments
 (0)