Skip to content

Commit 345ce50

Browse files
committed
Try to parallelize a bit
1 parent 23c57c8 commit 345ce50

2 files changed

Lines changed: 234 additions & 27 deletions

File tree

astra_app/core/freeipa/user.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from collections.abc import Collection
3+
from concurrent.futures import ThreadPoolExecutor
34

45
from django.conf import settings
56
from django.utils.crypto import salted_hmac
@@ -29,6 +30,10 @@
2930

3031
logger = logging.getLogger("core.backends")
3132

33+
_LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD = 2
34+
_LIGHTWEIGHT_LOOKUP_CHUNK_SIZE = 2
35+
_LIGHTWEIGHT_LOOKUP_MAX_WORKERS = 4
36+
3237

3338
class _FreeIPAPK:
3439
attname = 'username'
@@ -423,38 +428,65 @@ def find_lightweight_by_usernames(cls, usernames: Collection[str]) -> dict[str,
423428
if not normalized_usernames:
424429
return {}
425430

426-
def _do(client: ClientMeta) -> dict[str, FreeIPAUser]:
427-
users_by_username: dict[str, FreeIPAUser] = {}
428-
for username in normalized_usernames:
429-
result = client.user_find(
430-
o_uid=username,
431-
o_all=False,
432-
o_no_members=True,
433-
o_sizelimit=1,
434-
o_timelimit=0,
435-
)
436-
if not isinstance(result, dict) or result.get("count", 0) <= 0:
437-
continue
431+
def _lookup_chunk(chunk_usernames: list[str]) -> dict[str, FreeIPAUser]:
432+
def _do(client: ClientMeta) -> dict[str, FreeIPAUser]:
433+
users_by_username: dict[str, FreeIPAUser] = {}
434+
for username in chunk_usernames:
435+
result = client.user_find(
436+
o_uid=username,
437+
o_all=False,
438+
o_no_members=True,
439+
o_sizelimit=1,
440+
o_timelimit=0,
441+
)
442+
if not isinstance(result, dict) or result.get("count", 0) <= 0:
443+
continue
438444

439-
first = (result.get("result") or [None])[0]
440-
if not isinstance(first, dict):
441-
continue
445+
first = (result.get("result") or [None])[0]
446+
if not isinstance(first, dict):
447+
continue
442448

443-
uid = first.get("uid")
444-
if isinstance(uid, list):
445-
resolved_username = (uid[0] if uid else "") or ""
446-
else:
447-
resolved_username = uid or ""
448-
resolved_username = str(resolved_username).strip().lower()
449-
if not resolved_username:
450-
continue
449+
uid = first.get("uid")
450+
if isinstance(uid, list):
451+
resolved_username = (uid[0] if uid else "") or ""
452+
else:
453+
resolved_username = uid or ""
454+
resolved_username = str(resolved_username).strip().lower()
455+
if not resolved_username:
456+
continue
451457

452-
users_by_username[resolved_username] = cls(resolved_username, first)
458+
users_by_username[resolved_username] = cls(resolved_username, first)
453459

454-
return users_by_username
460+
return users_by_username
455461

456-
try:
457462
return _with_freeipa_service_client_retry(cls.get_client, _do)
463+
464+
try:
465+
if len(normalized_usernames) <= _LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD:
466+
return _lookup_chunk(normalized_usernames)
467+
468+
chunk_size = max(1, _LIGHTWEIGHT_LOOKUP_CHUNK_SIZE)
469+
username_chunks = [
470+
normalized_usernames[index:index + chunk_size]
471+
for index in range(0, len(normalized_usernames), chunk_size)
472+
]
473+
if len(username_chunks) == 1:
474+
return _lookup_chunk(normalized_usernames)
475+
476+
max_workers = max(1, min(_LIGHTWEIGHT_LOOKUP_MAX_WORKERS, len(username_chunks)))
477+
if max_workers == 1:
478+
return _lookup_chunk(normalized_usernames)
479+
480+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
481+
chunk_futures = [
482+
executor.submit(_lookup_chunk, chunk_usernames)
483+
for chunk_usernames in username_chunks
484+
]
485+
486+
users_by_username: dict[str, FreeIPAUser] = {}
487+
for chunk_future in chunk_futures:
488+
users_by_username.update(chunk_future.result())
489+
return users_by_username
458490
except Exception as e:
459491
logger.exception(
460492
f"Failed to find lightweight users usernames={normalized_usernames}: {e}",

astra_app/core/tests/test_freeipa_user_request_time_lookup.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import threading
2+
import time
13
from collections.abc import Callable
24
from unittest.mock import patch
35

46
from django.core.cache import cache
57
from django.test import TestCase
8+
from python_freeipa import exceptions
69

10+
from core.freeipa.client import clear_freeipa_service_client_cache
711
from core.freeipa.user import FreeIPAUser
812
from core.freeipa.utils import _user_cache_key
913
from core.freeipa_directory import search_freeipa_users
@@ -46,6 +50,73 @@ def user_show(self, username: str, *args: object, **kwargs: object) -> dict[str,
4650
return {"result": self.user_show_results[username]}
4751

4852

53+
class _LookupExecutionTracker:
54+
def __init__(
55+
self,
56+
*,
57+
transient_unauthorized_by_username: dict[str, int] | None = None,
58+
fatal_usernames: set[str] | None = None,
59+
) -> None:
60+
self._lock = threading.Lock()
61+
self.client_count = 0
62+
self.active_calls = 0
63+
self.max_active_calls = 0
64+
self.lookup_counts_by_username: dict[str, int] = {}
65+
self.transient_unauthorized_by_username = dict(transient_unauthorized_by_username or {})
66+
self.fatal_usernames = set(fatal_usernames or set())
67+
68+
def register_client(self) -> int:
69+
with self._lock:
70+
self.client_count += 1
71+
return self.client_count
72+
73+
def begin_lookup(self, username: str) -> None:
74+
with self._lock:
75+
self.active_calls += 1
76+
if self.active_calls > self.max_active_calls:
77+
self.max_active_calls = self.active_calls
78+
self.lookup_counts_by_username[username] = self.lookup_counts_by_username.get(username, 0) + 1
79+
80+
def end_lookup(self) -> None:
81+
with self._lock:
82+
self.active_calls -= 1
83+
84+
def consume_transient_unauthorized(self, username: str) -> bool:
85+
with self._lock:
86+
remaining = self.transient_unauthorized_by_username.get(username, 0)
87+
if remaining <= 0:
88+
return False
89+
self.transient_unauthorized_by_username[username] = remaining - 1
90+
return True
91+
92+
93+
class _TrackingShapeLookupClient(_ShapeLookupClient):
94+
def __init__(
95+
self,
96+
*,
97+
tracker: _LookupExecutionTracker,
98+
username_results: dict[str, dict[str, object]],
99+
pause_seconds: float = 0.0,
100+
) -> None:
101+
super().__init__(username_results=username_results)
102+
self.tracker = tracker
103+
self.pause_seconds = pause_seconds
104+
105+
def user_find(self, *args: object, **kwargs: object) -> dict[str, object]:
106+
username = str(kwargs.get("o_uid") or "").strip().lower()
107+
self.tracker.begin_lookup(username)
108+
try:
109+
if self.pause_seconds:
110+
time.sleep(self.pause_seconds)
111+
if username in self.tracker.fatal_usernames:
112+
raise RuntimeError(f"boom for {username}")
113+
if self.tracker.consume_transient_unauthorized(username):
114+
raise exceptions.Unauthorized()
115+
return super().user_find(*args, **kwargs)
116+
finally:
117+
self.tracker.end_lookup()
118+
119+
49120
def _lightweight_row(username: str, *, full_name: str | None = None) -> dict[str, object]:
50121
display_name = full_name or username
51122
return {
@@ -138,7 +209,7 @@ def test_find_lightweight_by_usernames_deduplicates_and_uses_exact_shape(self) -
138209
self.assertEqual(set(users_by_username), {"alice", "bob"})
139210
self.assertEqual(users_by_username["alice"].full_name, "Alice Example")
140211
self.assertEqual(users_by_username["bob"].full_name, "Bob Example")
141-
self.assertEqual(
212+
self.assertCountEqual(
142213
client.user_find_calls,
143214
[
144215
{
@@ -158,6 +229,110 @@ def test_find_lightweight_by_usernames_deduplicates_and_uses_exact_shape(self) -
158229
],
159230
)
160231

232+
def test_find_lightweight_by_usernames_uses_serial_fast_path_for_small_inputs(self) -> None:
233+
tracker = _LookupExecutionTracker()
234+
235+
def build_client() -> _TrackingShapeLookupClient:
236+
tracker.register_client()
237+
return _TrackingShapeLookupClient(
238+
tracker=tracker,
239+
username_results={
240+
"alice": {"count": 1, "result": [_lightweight_row("alice", full_name="Alice Example")]},
241+
"bob": {"count": 1, "result": [_lightweight_row("bob", full_name="Bob Example")]},
242+
},
243+
pause_seconds=0.01,
244+
)
245+
246+
with (
247+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD", 2, create=True),
248+
patch("core.freeipa.user.FreeIPAUser.get_client", side_effect=build_client),
249+
):
250+
users_by_username = FreeIPAUser.find_lightweight_by_usernames(["alice", "bob"])
251+
252+
self.assertEqual(set(users_by_username), {"alice", "bob"})
253+
self.assertEqual(tracker.client_count, 1)
254+
self.assertEqual(tracker.max_active_calls, 1)
255+
256+
def test_find_lightweight_by_usernames_bounds_service_client_fanout(self) -> None:
257+
tracker = _LookupExecutionTracker()
258+
usernames = [f"user{index}" for index in range(8)]
259+
max_workers = 3
260+
username_results = {
261+
username: {"count": 1, "result": [_lightweight_row(username, full_name=f"{username} Example")]}
262+
for username in usernames
263+
}
264+
265+
def build_client(*_args: object, **_kwargs: object) -> _TrackingShapeLookupClient:
266+
tracker.register_client()
267+
return _TrackingShapeLookupClient(
268+
tracker=tracker,
269+
username_results=username_results,
270+
pause_seconds=0.02,
271+
)
272+
273+
with (
274+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD", 2, create=True),
275+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_CHUNK_SIZE", 2, create=True),
276+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_MAX_WORKERS", max_workers, create=True),
277+
patch("core.freeipa.client._get_freeipa_client", side_effect=build_client),
278+
):
279+
clear_freeipa_service_client_cache()
280+
users_by_username = FreeIPAUser.find_lightweight_by_usernames(usernames)
281+
clear_freeipa_service_client_cache()
282+
283+
self.assertEqual(set(users_by_username), set(usernames))
284+
self.assertGreater(tracker.client_count, 1)
285+
self.assertLessEqual(tracker.client_count, max_workers)
286+
self.assertGreater(tracker.max_active_calls, 1)
287+
288+
def test_find_lightweight_by_usernames_retries_unauthorized_within_worker_chunk(self) -> None:
289+
tracker = _LookupExecutionTracker(transient_unauthorized_by_username={"charlie": 1})
290+
usernames = ["alice", "bob", "charlie", "dave"]
291+
username_results = {
292+
username: {"count": 1, "result": [_lightweight_row(username, full_name=f"{username} Example")]}
293+
for username in usernames
294+
}
295+
296+
def build_client() -> _TrackingShapeLookupClient:
297+
tracker.register_client()
298+
return _TrackingShapeLookupClient(tracker=tracker, username_results=username_results)
299+
300+
with (
301+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD", 2, create=True),
302+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_CHUNK_SIZE", 2, create=True),
303+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_MAX_WORKERS", 2, create=True),
304+
patch("core.freeipa.user.FreeIPAUser.get_client", side_effect=build_client),
305+
):
306+
users_by_username = FreeIPAUser.find_lightweight_by_usernames(usernames)
307+
308+
self.assertEqual(set(users_by_username), set(usernames))
309+
self.assertEqual(tracker.lookup_counts_by_username.get("alice"), 1)
310+
self.assertEqual(tracker.lookup_counts_by_username.get("bob"), 1)
311+
self.assertEqual(tracker.lookup_counts_by_username.get("charlie"), 2)
312+
self.assertEqual(tracker.lookup_counts_by_username.get("dave"), 1)
313+
314+
def test_find_lightweight_by_usernames_returns_empty_dict_when_any_worker_fails(self) -> None:
315+
tracker = _LookupExecutionTracker(fatal_usernames={"charlie"})
316+
usernames = ["alice", "bob", "charlie", "dave"]
317+
username_results = {
318+
username: {"count": 1, "result": [_lightweight_row(username, full_name=f"{username} Example")]}
319+
for username in usernames
320+
}
321+
322+
def build_client() -> _TrackingShapeLookupClient:
323+
tracker.register_client()
324+
return _TrackingShapeLookupClient(tracker=tracker, username_results=username_results)
325+
326+
with (
327+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD", 2, create=True),
328+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_CHUNK_SIZE", 2, create=True),
329+
patch("core.freeipa.user._LIGHTWEIGHT_LOOKUP_MAX_WORKERS", 2, create=True),
330+
patch("core.freeipa.user.FreeIPAUser.get_client", side_effect=build_client),
331+
):
332+
users_by_username = FreeIPAUser.find_lightweight_by_usernames(usernames)
333+
334+
self.assertEqual(users_by_username, {})
335+
161336
def test_find_lightweight_by_usernames_omits_missing_usernames(self) -> None:
162337
client = _ShapeLookupClient(
163338
username_results={

0 commit comments

Comments
 (0)