Skip to content

Commit cefdc61

Browse files
committed
Improve CLI login tests
1 parent 3366c09 commit cefdc61

2 files changed

Lines changed: 140 additions & 43 deletions

File tree

tests/cli/conftest.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,11 @@ def cli_env(monkeypatch, tmp_path, demo_dir):
3434

3535
@pytest.fixture
3636
async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path):
37-
from .test_login import test_login
37+
from .test_login import do_successful_login
3838

3939
try:
40-
credentials = await test_login(monkeypatch, capfd, cli_env)
40+
await do_successful_login(monkeypatch, capfd, cli_env)
4141
except Exception:
42-
pytest.skip("Login failed, fix test_login to re-enable this test")
42+
pytest.xfail("Login failed, fix test_login to re-enable this test")
4343

44-
credentials_path = tmp_path / "credentials.json"
45-
credentials_path.write_text(credentials)
46-
monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(credentials_path))
4744
yield

tests/cli/test_login.py

Lines changed: 137 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,112 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import base64
5+
import json
46
import re
57
from html.parser import HTMLParser
68
from pathlib import Path
79
from urllib.parse import urljoin
810

11+
import pytest
912
import requests
1013

1114
from diracx import cli
1215

1316

14-
def do_device_flow_with_dex(url: str) -> None:
15-
"""Do the device flow with dex"""
17+
async def test_login(monkeypatch, capfd, cli_env):
18+
"""Test that the CLI can login successfully"""
19+
expected_credentials_path = Path(
20+
cli_env["HOME"], ".cache", "diracx", "credentials.json"
21+
)
1622

17-
class DexLoginFormParser(HTMLParser):
18-
def handle_starttag(self, tag, attrs):
19-
nonlocal action_url
20-
if "form" in str(tag):
21-
assert action_url is None
22-
action_url = urljoin(login_page_url, dict(attrs)["action"])
23+
# Ensure the credentials file does not exist before logging in
24+
assert not expected_credentials_path.exists()
2325

24-
# Get the login page
25-
r = requests.get(url)
26-
r.raise_for_status()
27-
login_page_url = r.url # This is not the same as URL as we redirect to dex
28-
login_page_body = r.text
26+
# Do the actual login
27+
await do_successful_login(monkeypatch, capfd, cli_env)
2928

30-
# Search the page for the login form so we know where to post the credentials
31-
action_url = None
32-
DexLoginFormParser().feed(login_page_body)
33-
assert action_url is not None, login_page_body
29+
# Ensure the credentials file exists after logging in
30+
assert expected_credentials_path.exists()
31+
32+
33+
async def test_invalid_credentials_file(monkeypatch, capfd, cli_env):
34+
"""Test that the CLI can handle an invalid credentials file"""
35+
expected_credentials_path = Path(
36+
cli_env["HOME"], ".cache", "diracx", "credentials.json"
37+
)
38+
expected_credentials_path.parent.mkdir(parents=True, exist_ok=True)
39+
expected_credentials_path.write_text("invalid json")
3440

3541
# Do the actual login
36-
r = requests.post(
37-
action_url, data={"login": "admin@example.com", "password": "password"}
42+
await do_successful_login(monkeypatch, capfd, cli_env)
43+
44+
45+
async def test_invalid_access_token(cli_env, monkeypatch, capfd, with_cli_login):
46+
"""Test that the CLI can handle an invalid access token
47+
48+
We expect the CLI to detect the invalid access token and use the refresh
49+
token to get a new access token without prompting the user to login again.
50+
"""
51+
expected_credentials_path = Path(
52+
cli_env["HOME"], ".cache", "diracx", "credentials.json"
3853
)
39-
r.raise_for_status()
40-
# This should have redirected to the DiracX page that shows the login is complete
41-
assert "Please close the window" in r.text
4254

55+
credentials = json.loads(expected_credentials_path.read_text())
56+
bad_credentials = credentials | {
57+
"access_token": make_invalid_jwt(credentials["access_token"]),
58+
"expires_on": credentials["expires_on"] - 3600,
59+
}
60+
expected_credentials_path.write_text(json.dumps(bad_credentials))
4361

44-
async def test_login(monkeypatch, capfd, cli_env):
62+
# See if the credentials still work
63+
await cli.whoami()
64+
cap = capfd.readouterr()
65+
assert cap.err == ""
66+
assert json.loads(cap.out)["vo"] == "diracAdmin"
67+
68+
69+
@pytest.mark.xfail(reason="TODO: Implement nicer error handling in the CLI")
70+
async def test_invalid_refresh_token(cli_env, monkeypatch, capfd, with_cli_login):
71+
"""Test that the CLI can handle an invalid refresh token
72+
73+
We expect the CLI to detect the invalid refresh token and prompt the user
74+
to login again.
75+
"""
76+
expected_credentials_path = Path(
77+
cli_env["HOME"], ".cache", "diracx", "credentials.json"
78+
)
79+
80+
credentials = json.loads(expected_credentials_path.read_text())
81+
bad_credentials = credentials | {
82+
"refresh_token": make_invalid_jwt(credentials["refresh_token"]),
83+
"expires_on": credentials["expires_on"] - 3600,
84+
}
85+
expected_credentials_path.write_text(json.dumps(bad_credentials))
86+
87+
with pytest.raises(SystemExit):
88+
await cli.whoami()
89+
cap = capfd.readouterr()
90+
assert cap.out == ""
91+
assert "dirac login" in cap.err
92+
93+
# Having invalid credentials should prompt the user to login again
94+
await do_successful_login(monkeypatch, capfd, cli_env)
95+
96+
# See if the credentials work
97+
await cli.whoami()
98+
cap = capfd.readouterr()
99+
assert cap.err == ""
100+
assert json.loads(cap.out)["vo"] == "diracAdmin"
101+
102+
103+
# ###############################################
104+
# The rest of this file contains helper functions
105+
# ###############################################
106+
107+
108+
async def do_successful_login(monkeypatch, capfd, cli_env):
109+
"""Do a successful login using the CLI"""
45110
poll_attempts = 0
46111

47112
def fake_sleep(*args, **kwargs):
@@ -69,24 +134,59 @@ def fake_sleep(*args, **kwargs):
69134
# would normally be done by a user. This includes capturing the login URL
70135
# and doing the actual device flow with dex.
71136
unpatched_sleep = asyncio.sleep
72-
monkeypatch.setattr("asyncio.sleep", fake_sleep)
73-
74-
expected_credentials_path = Path(
75-
cli_env["HOME"], ".cache", "diracx", "credentials.json"
76-
)
137+
with monkeypatch.context() as m:
138+
m.setattr("asyncio.sleep", fake_sleep)
77139

78-
# Ensure the credentials file does not exist before logging in
79-
assert not expected_credentials_path.exists()
140+
# Run the login command
141+
await cli.login(vo="diracAdmin", group=None, property=None)
80142

81-
# Run the login command
82-
await cli.login(vo="diracAdmin", group=None, property=None)
83143
captured = capfd.readouterr()
84144
assert "Login successful!" in captured.out
85145
assert captured.err == ""
86146

87-
# Ensure the credentials file exists after logging in
88-
assert expected_credentials_path.exists()
89147

90-
# Return the credentials so this test can also be used by the
91-
# "with_cli_login" fixture
92-
return expected_credentials_path.read_text()
148+
def do_device_flow_with_dex(url: str) -> None:
149+
"""Do the device flow with dex"""
150+
151+
class DexLoginFormParser(HTMLParser):
152+
def handle_starttag(self, tag, attrs):
153+
nonlocal action_url
154+
if "form" in str(tag):
155+
assert action_url is None
156+
action_url = urljoin(login_page_url, dict(attrs)["action"])
157+
158+
# Get the login page
159+
r = requests.get(url)
160+
r.raise_for_status()
161+
login_page_url = r.url # This is not the same as URL as we redirect to dex
162+
login_page_body = r.text
163+
164+
# Search the page for the login form so we know where to post the credentials
165+
action_url = None
166+
DexLoginFormParser().feed(login_page_body)
167+
assert action_url is not None, login_page_body
168+
169+
# Do the actual login
170+
r = requests.post(
171+
action_url, data={"login": "admin@example.com", "password": "password"}
172+
)
173+
r.raise_for_status()
174+
# This should have redirected to the DiracX page that shows the login is complete
175+
assert "Please close the window" in r.text
176+
177+
178+
def make_invalid_jwt(jwt: str) -> str:
179+
"""Make an invalid JWT by reversing the signature"""
180+
header, payload, signature = jwt.split(".")
181+
# JWT's don't have padding but base64.b64decode expects it
182+
raw_signature = base64.urlsafe_b64decode(pad_base64(signature))
183+
bad_signature = base64.urlsafe_b64encode(raw_signature[::-1])
184+
return ".".join([header, payload, bad_signature.decode("ascii").rstrip("=")])
185+
186+
187+
def pad_base64(data):
188+
"""Add padding to base64 data"""
189+
missing_padding = len(data) % 4
190+
if missing_padding != 0:
191+
data += "=" * (4 - missing_padding)
192+
return data

0 commit comments

Comments
 (0)