|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import base64 |
| 5 | +import json |
4 | 6 | import re |
5 | 7 | from html.parser import HTMLParser |
6 | 8 | from pathlib import Path |
7 | 9 | from urllib.parse import urljoin |
8 | 10 |
|
| 11 | +import pytest |
9 | 12 | import requests |
10 | 13 |
|
11 | 14 | from diracx import cli |
12 | 15 |
|
13 | 16 |
|
14 | | -def do_device_flow_with_dex(url: str) -> None: |
15 | | - """Do the device flow with dex""" |
| 17 | +async def test_login(monkeypatch, capfd, cli_env): |
| 18 | + """Test that the CLI can login successfully""" |
| 19 | + expected_credentials_path = Path( |
| 20 | + cli_env["HOME"], ".cache", "diracx", "credentials.json" |
| 21 | + ) |
16 | 22 |
|
17 | | - class DexLoginFormParser(HTMLParser): |
18 | | - def handle_starttag(self, tag, attrs): |
19 | | - nonlocal action_url |
20 | | - if "form" in str(tag): |
21 | | - assert action_url is None |
22 | | - action_url = urljoin(login_page_url, dict(attrs)["action"]) |
| 23 | + # Ensure the credentials file does not exist before logging in |
| 24 | + assert not expected_credentials_path.exists() |
23 | 25 |
|
24 | | - # Get the login page |
25 | | - r = requests.get(url) |
26 | | - r.raise_for_status() |
27 | | - login_page_url = r.url # This is not the same as URL as we redirect to dex |
28 | | - login_page_body = r.text |
| 26 | + # Do the actual login |
| 27 | + await do_successful_login(monkeypatch, capfd, cli_env) |
29 | 28 |
|
30 | | - # Search the page for the login form so we know where to post the credentials |
31 | | - action_url = None |
32 | | - DexLoginFormParser().feed(login_page_body) |
33 | | - assert action_url is not None, login_page_body |
| 29 | + # Ensure the credentials file exists after logging in |
| 30 | + assert expected_credentials_path.exists() |
| 31 | + |
| 32 | + |
| 33 | +async def test_invalid_credentials_file(monkeypatch, capfd, cli_env): |
| 34 | + """Test that the CLI can handle an invalid credentials file""" |
| 35 | + expected_credentials_path = Path( |
| 36 | + cli_env["HOME"], ".cache", "diracx", "credentials.json" |
| 37 | + ) |
| 38 | + expected_credentials_path.parent.mkdir(parents=True, exist_ok=True) |
| 39 | + expected_credentials_path.write_text("invalid json") |
34 | 40 |
|
35 | 41 | # Do the actual login |
36 | | - r = requests.post( |
37 | | - action_url, data={"login": "admin@example.com", "password": "password"} |
| 42 | + await do_successful_login(monkeypatch, capfd, cli_env) |
| 43 | + |
| 44 | + |
| 45 | +async def test_invalid_access_token(cli_env, monkeypatch, capfd, with_cli_login): |
| 46 | + """Test that the CLI can handle an invalid access token |
| 47 | +
|
| 48 | + We expect the CLI to detect the invalid access token and use the refresh |
| 49 | + token to get a new access token without prompting the user to login again. |
| 50 | + """ |
| 51 | + expected_credentials_path = Path( |
| 52 | + cli_env["HOME"], ".cache", "diracx", "credentials.json" |
38 | 53 | ) |
39 | | - r.raise_for_status() |
40 | | - # This should have redirected to the DiracX page that shows the login is complete |
41 | | - assert "Please close the window" in r.text |
42 | 54 |
|
| 55 | + credentials = json.loads(expected_credentials_path.read_text()) |
| 56 | + bad_credentials = credentials | { |
| 57 | + "access_token": make_invalid_jwt(credentials["access_token"]), |
| 58 | + "expires_on": credentials["expires_on"] - 3600, |
| 59 | + } |
| 60 | + expected_credentials_path.write_text(json.dumps(bad_credentials)) |
43 | 61 |
|
44 | | -async def test_login(monkeypatch, capfd, cli_env): |
| 62 | + # See if the credentials still work |
| 63 | + await cli.whoami() |
| 64 | + cap = capfd.readouterr() |
| 65 | + assert cap.err == "" |
| 66 | + assert json.loads(cap.out)["vo"] == "diracAdmin" |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.xfail(reason="TODO: Implement nicer error handling in the CLI") |
| 70 | +async def test_invalid_refresh_token(cli_env, monkeypatch, capfd, with_cli_login): |
| 71 | + """Test that the CLI can handle an invalid refresh token |
| 72 | +
|
| 73 | + We expect the CLI to detect the invalid refresh token and prompt the user |
| 74 | + to login again. |
| 75 | + """ |
| 76 | + expected_credentials_path = Path( |
| 77 | + cli_env["HOME"], ".cache", "diracx", "credentials.json" |
| 78 | + ) |
| 79 | + |
| 80 | + credentials = json.loads(expected_credentials_path.read_text()) |
| 81 | + bad_credentials = credentials | { |
| 82 | + "refresh_token": make_invalid_jwt(credentials["refresh_token"]), |
| 83 | + "expires_on": credentials["expires_on"] - 3600, |
| 84 | + } |
| 85 | + expected_credentials_path.write_text(json.dumps(bad_credentials)) |
| 86 | + |
| 87 | + with pytest.raises(SystemExit): |
| 88 | + await cli.whoami() |
| 89 | + cap = capfd.readouterr() |
| 90 | + assert cap.out == "" |
| 91 | + assert "dirac login" in cap.err |
| 92 | + |
| 93 | + # Having invalid credentials should prompt the user to login again |
| 94 | + await do_successful_login(monkeypatch, capfd, cli_env) |
| 95 | + |
| 96 | + # See if the credentials work |
| 97 | + await cli.whoami() |
| 98 | + cap = capfd.readouterr() |
| 99 | + assert cap.err == "" |
| 100 | + assert json.loads(cap.out)["vo"] == "diracAdmin" |
| 101 | + |
| 102 | + |
| 103 | +# ############################################### |
| 104 | +# The rest of this file contains helper functions |
| 105 | +# ############################################### |
| 106 | + |
| 107 | + |
| 108 | +async def do_successful_login(monkeypatch, capfd, cli_env): |
| 109 | + """Do a successful login using the CLI""" |
45 | 110 | poll_attempts = 0 |
46 | 111 |
|
47 | 112 | def fake_sleep(*args, **kwargs): |
@@ -69,24 +134,59 @@ def fake_sleep(*args, **kwargs): |
69 | 134 | # would normally be done by a user. This includes capturing the login URL |
70 | 135 | # and doing the actual device flow with dex. |
71 | 136 | unpatched_sleep = asyncio.sleep |
72 | | - monkeypatch.setattr("asyncio.sleep", fake_sleep) |
73 | | - |
74 | | - expected_credentials_path = Path( |
75 | | - cli_env["HOME"], ".cache", "diracx", "credentials.json" |
76 | | - ) |
| 137 | + with monkeypatch.context() as m: |
| 138 | + m.setattr("asyncio.sleep", fake_sleep) |
77 | 139 |
|
78 | | - # Ensure the credentials file does not exist before logging in |
79 | | - assert not expected_credentials_path.exists() |
| 140 | + # Run the login command |
| 141 | + await cli.login(vo="diracAdmin", group=None, property=None) |
80 | 142 |
|
81 | | - # Run the login command |
82 | | - await cli.login(vo="diracAdmin", group=None, property=None) |
83 | 143 | captured = capfd.readouterr() |
84 | 144 | assert "Login successful!" in captured.out |
85 | 145 | assert captured.err == "" |
86 | 146 |
|
87 | | - # Ensure the credentials file exists after logging in |
88 | | - assert expected_credentials_path.exists() |
89 | 147 |
|
90 | | - # Return the credentials so this test can also be used by the |
91 | | - # "with_cli_login" fixture |
92 | | - return expected_credentials_path.read_text() |
| 148 | +def do_device_flow_with_dex(url: str) -> None: |
| 149 | + """Do the device flow with dex""" |
| 150 | + |
| 151 | + class DexLoginFormParser(HTMLParser): |
| 152 | + def handle_starttag(self, tag, attrs): |
| 153 | + nonlocal action_url |
| 154 | + if "form" in str(tag): |
| 155 | + assert action_url is None |
| 156 | + action_url = urljoin(login_page_url, dict(attrs)["action"]) |
| 157 | + |
| 158 | + # Get the login page |
| 159 | + r = requests.get(url) |
| 160 | + r.raise_for_status() |
| 161 | + login_page_url = r.url # This is not the same as URL as we redirect to dex |
| 162 | + login_page_body = r.text |
| 163 | + |
| 164 | + # Search the page for the login form so we know where to post the credentials |
| 165 | + action_url = None |
| 166 | + DexLoginFormParser().feed(login_page_body) |
| 167 | + assert action_url is not None, login_page_body |
| 168 | + |
| 169 | + # Do the actual login |
| 170 | + r = requests.post( |
| 171 | + action_url, data={"login": "admin@example.com", "password": "password"} |
| 172 | + ) |
| 173 | + r.raise_for_status() |
| 174 | + # This should have redirected to the DiracX page that shows the login is complete |
| 175 | + assert "Please close the window" in r.text |
| 176 | + |
| 177 | + |
| 178 | +def make_invalid_jwt(jwt: str) -> str: |
| 179 | + """Make an invalid JWT by reversing the signature""" |
| 180 | + header, payload, signature = jwt.split(".") |
| 181 | + # JWT's don't have padding but base64.b64decode expects it |
| 182 | + raw_signature = base64.urlsafe_b64decode(pad_base64(signature)) |
| 183 | + bad_signature = base64.urlsafe_b64encode(raw_signature[::-1]) |
| 184 | + return ".".join([header, payload, bad_signature.decode("ascii").rstrip("=")]) |
| 185 | + |
| 186 | + |
| 187 | +def pad_base64(data): |
| 188 | + """Add padding to base64 data""" |
| 189 | + missing_padding = len(data) % 4 |
| 190 | + if missing_padding != 0: |
| 191 | + data += "=" * (4 - missing_padding) |
| 192 | + return data |
0 commit comments