|
| 1 | +from datetime import UTC, datetime |
| 2 | +from pathlib import Path |
| 3 | +from unittest.mock import MagicMock, Mock, patch |
| 4 | + |
| 5 | +from infer_check.types import SweepResult |
| 6 | + |
| 7 | + |
| 8 | +def test_sweep_model_parsing_robustness() -> None: |
| 9 | + """Test that sweep command parses model paths robustly, handling extra equals signs.""" |
| 10 | + # Create a mock SweepResult to return from runner.sweep |
| 11 | + mock_sweep_result = SweepResult( |
| 12 | + model_id="test-model", |
| 13 | + backend_name="test-backend", |
| 14 | + quantization_levels=["bf16", "4bit"], |
| 15 | + comparisons=[], |
| 16 | + timestamp=datetime.now(UTC), |
| 17 | + summary={}, |
| 18 | + ) |
| 19 | + |
| 20 | + # We mock get_backend_for_model and TestRunner.sweep to avoid actual initialization |
| 21 | + with ( |
| 22 | + patch("infer_check.backends.base.get_backend_for_model") as mock_get_backend, |
| 23 | + patch("infer_check.runner.TestRunner.sweep", new_callable=Mock), |
| 24 | + patch("infer_check.suites.loader.load_suite", return_value=[MagicMock()]), |
| 25 | + patch("infer_check.cli._resolve_prompts", return_value=Path("dummy.jsonl")), |
| 26 | + patch("asyncio.run", return_value=mock_sweep_result), |
| 27 | + ): |
| 28 | + mock_get_backend.return_value.name = "test-backend" |
| 29 | + # Simulating the command: infer-check sweep --models "bf16==path/to/model" --prompts dummy |
| 30 | + # We call the function directly as click command |
| 31 | + from click.testing import CliRunner |
| 32 | + |
| 33 | + from infer_check.cli import main |
| 34 | + |
| 35 | + runner = CliRunner() |
| 36 | + # Using a subset of arguments to trigger the parsing logic |
| 37 | + with runner.isolated_filesystem(): |
| 38 | + result = runner.invoke( |
| 39 | + main, ["sweep", "--models", "bf16==bartowski/Qwen,4bit=bartowski/Qwen", "--prompts", "reasoning"] |
| 40 | + ) |
| 41 | + assert result.exit_code == 0, result.output |
| 42 | + |
| 43 | + # Check if get_backend_for_model was called with cleaned paths |
| 44 | + # It should be called twice: once for bf16 and once for 4bit |
| 45 | + assert mock_get_backend.call_count == 2 |
| 46 | + |
| 47 | + # Check first call (bf16) |
| 48 | + args, kwargs = mock_get_backend.call_args_list[0] |
| 49 | + assert kwargs["model_str"] == "bartowski/Qwen" |
| 50 | + assert kwargs["quantization"] == "bf16" |
| 51 | + |
| 52 | + # Check second call (4bit) |
| 53 | + args, kwargs = mock_get_backend.call_args_list[1] |
| 54 | + assert kwargs["model_str"] == "bartowski/Qwen" |
| 55 | + assert kwargs["quantization"] == "4bit" |
0 commit comments