|
1 | 1 | """Tests for the CodeCarbon CLI main function.""" |
2 | 2 |
|
| 3 | +from types import SimpleNamespace |
| 4 | + |
| 5 | +import pytest |
3 | 6 | from typer.testing import CliRunner |
4 | 7 |
|
5 | 8 | from codecarbon.cli import main as cli_main |
@@ -112,3 +115,183 @@ def fake_get_access_token(): |
112 | 115 | captured = capsys.readouterr() |
113 | 116 | assert "Could not validate remote configuration details" in captured.out |
114 | 117 | assert "Not able to retrieve the access token" in captured.out |
| 118 | + |
| 119 | + |
| 120 | +def test_main_exits_with_error_when_command_raises(monkeypatch, capsys): |
| 121 | + def fake_cli(): |
| 122 | + raise RuntimeError("boom") |
| 123 | + |
| 124 | + monkeypatch.setattr(cli_main, "codecarbon", fake_cli) |
| 125 | + |
| 126 | + with pytest.raises(SystemExit) as exc_info: |
| 127 | + cli_main.main() |
| 128 | + |
| 129 | + captured = capsys.readouterr() |
| 130 | + assert exc_info.value.code == 1 |
| 131 | + assert "Error:" in captured.out |
| 132 | + assert "boom" in captured.out |
| 133 | + |
| 134 | + |
| 135 | +def test_login_calls_authorize_and_auth_check(monkeypatch): |
| 136 | + calls = {"authorize": 0, "set_token": None, "check_auth": 0} |
| 137 | + |
| 138 | + class FakeApiClient: |
| 139 | + def __init__(self, endpoint_url=None): |
| 140 | + self.endpoint_url = endpoint_url |
| 141 | + |
| 142 | + def set_access_token(self, token): |
| 143 | + calls["set_token"] = token |
| 144 | + |
| 145 | + def check_auth(self): |
| 146 | + calls["check_auth"] += 1 |
| 147 | + |
| 148 | + monkeypatch.setattr(cli_main, "ApiClient", FakeApiClient) |
| 149 | + monkeypatch.setattr( |
| 150 | + cli_main, |
| 151 | + "authorize", |
| 152 | + lambda: calls.__setitem__("authorize", calls["authorize"] + 1), |
| 153 | + ) |
| 154 | + monkeypatch.setattr(cli_main, "get_access_token", lambda: "login-token") |
| 155 | + |
| 156 | + runner = CliRunner() |
| 157 | + result = runner.invoke(cli_main.codecarbon, ["login"]) |
| 158 | + assert result.exit_code == 0 |
| 159 | + assert calls["authorize"] == 1 |
| 160 | + assert calls["set_token"] == "login-token" |
| 161 | + assert calls["check_auth"] == 1 |
| 162 | + |
| 163 | + |
| 164 | +def test_get_api_key_uses_bearer_token(monkeypatch): |
| 165 | + captured = {} |
| 166 | + |
| 167 | + class FakeResponse: |
| 168 | + def json(self): |
| 169 | + return {"token": "project-api-token"} |
| 170 | + |
| 171 | + def fake_post(url, json, headers): |
| 172 | + captured["url"] = url |
| 173 | + captured["json"] = json |
| 174 | + captured["headers"] = headers |
| 175 | + return FakeResponse() |
| 176 | + |
| 177 | + monkeypatch.setattr(cli_main, "get_access_token", lambda: "access-token") |
| 178 | + monkeypatch.setattr(cli_main.requests, "post", fake_post) |
| 179 | + |
| 180 | + token = cli_main.get_api_key("proj-123") |
| 181 | + assert token == "project-api-token" |
| 182 | + assert captured["url"].endswith("/projects/proj-123/api-tokens") |
| 183 | + assert captured["json"]["project_id"] == "proj-123" |
| 184 | + assert captured["headers"]["Authorization"] == "Bearer access-token" |
| 185 | + |
| 186 | + |
| 187 | +def test_get_token_command_prints_token(monkeypatch): |
| 188 | + monkeypatch.setattr(cli_main, "get_api_key", lambda project_id: "abc123") |
| 189 | + runner = CliRunner() |
| 190 | + result = runner.invoke(cli_main.codecarbon, ["get-token", "proj-id"]) |
| 191 | + assert result.exit_code == 0 |
| 192 | + assert "Your token: abc123" in result.output |
| 193 | + |
| 194 | + |
| 195 | +def test_show_config_prints_missing_project_and_experiment( |
| 196 | + monkeypatch, tmp_path, capsys |
| 197 | +): |
| 198 | + class FakeApiClient: |
| 199 | + def __init__(self, endpoint_url=None): |
| 200 | + self.endpoint_url = endpoint_url |
| 201 | + |
| 202 | + def set_access_token(self, token): |
| 203 | + self.token = token |
| 204 | + |
| 205 | + def get_organization(self, organization_id): |
| 206 | + return {"id": organization_id} |
| 207 | + |
| 208 | + def get_project(self, project_id): |
| 209 | + return {"id": project_id} |
| 210 | + |
| 211 | + def get_experiment(self, experiment_id): |
| 212 | + return {"id": experiment_id} |
| 213 | + |
| 214 | + monkeypatch.setattr(cli_main, "ApiClient", FakeApiClient) |
| 215 | + monkeypatch.setattr(cli_main, "get_access_token", lambda: "fake-token") |
| 216 | + monkeypatch.setattr( |
| 217 | + cli_main, "get_api_endpoint", lambda path: "https://api.codecarbon.io" |
| 218 | + ) |
| 219 | + |
| 220 | + monkeypatch.setattr( |
| 221 | + cli_main, |
| 222 | + "get_config", |
| 223 | + lambda path: { |
| 224 | + "api_endpoint": "https://api.codecarbon.io", |
| 225 | + "organization_id": "org-id", |
| 226 | + }, |
| 227 | + ) |
| 228 | + cli_main.show_config(tmp_path / ".codecarbon.config") |
| 229 | + captured = capsys.readouterr() |
| 230 | + assert "No project_id in config" in captured.out |
| 231 | + |
| 232 | + monkeypatch.setattr( |
| 233 | + cli_main, |
| 234 | + "get_config", |
| 235 | + lambda path: { |
| 236 | + "api_endpoint": "https://api.codecarbon.io", |
| 237 | + "organization_id": "org-id", |
| 238 | + "project_id": "project-id", |
| 239 | + }, |
| 240 | + ) |
| 241 | + cli_main.show_config(tmp_path / ".codecarbon.config") |
| 242 | + captured = capsys.readouterr() |
| 243 | + assert "No experiment_id in config" in captured.out |
| 244 | + |
| 245 | + |
| 246 | +def test_monitor_online_requires_experiment_id(monkeypatch): |
| 247 | + monkeypatch.setattr(cli_main, "get_existing_exp_id", lambda: None) |
| 248 | + runner = CliRunner() |
| 249 | + result = runner.invoke(cli_main.codecarbon, ["monitor"]) |
| 250 | + assert result.exit_code == 1 |
| 251 | + assert "No experiment id" in result.output |
| 252 | + |
| 253 | + |
| 254 | +def test_monitor_offline_initializes_offline_tracker(monkeypatch): |
| 255 | + calls = {"kwargs": None, "started": 0} |
| 256 | + |
| 257 | + class FakeOfflineTracker: |
| 258 | + def __init__(self, **kwargs): |
| 259 | + calls["kwargs"] = kwargs |
| 260 | + self._another_instance_already_running = True |
| 261 | + |
| 262 | + def start(self): |
| 263 | + calls["started"] += 1 |
| 264 | + |
| 265 | + def stop(self): |
| 266 | + return None |
| 267 | + |
| 268 | + monkeypatch.setattr(cli_main, "OfflineEmissionsTracker", FakeOfflineTracker) |
| 269 | + monkeypatch.setattr(cli_main.signal, "signal", lambda *args, **kwargs: None) |
| 270 | + |
| 271 | + runner = CliRunner() |
| 272 | + result = runner.invoke( |
| 273 | + cli_main.codecarbon, |
| 274 | + ["monitor", "--offline", "--country-iso-code", "FRA", "--region", "IDF"], |
| 275 | + ) |
| 276 | + assert result.exit_code == 0 |
| 277 | + assert calls["started"] == 1 |
| 278 | + assert calls["kwargs"]["country_iso_code"] == "FRA" |
| 279 | + assert calls["kwargs"]["region"] == "IDF" |
| 280 | + |
| 281 | + |
| 282 | +def test_monitor_delegates_to_run_and_monitor_with_extra_args(monkeypatch): |
| 283 | + captured = {} |
| 284 | + |
| 285 | + def fake_run_and_monitor(ctx, **kwargs): |
| 286 | + captured["args"] = list(ctx.args) |
| 287 | + captured["kwargs"] = kwargs |
| 288 | + return "ok" |
| 289 | + |
| 290 | + monkeypatch.setattr(cli_main, "run_and_monitor", fake_run_and_monitor) |
| 291 | + monkeypatch.setattr(cli_main, "get_existing_exp_id", lambda: "exp-1") |
| 292 | + |
| 293 | + ctx = SimpleNamespace(args=["python", "train.py"]) |
| 294 | + result = cli_main.monitor(ctx=ctx, api=False) |
| 295 | + assert result == "ok" |
| 296 | + assert captured["args"] == ["python", "train.py"] |
| 297 | + assert captured["kwargs"]["save_to_api"] is False |
0 commit comments