1+ import threading
2+ import time
13from collections .abc import Callable
24from unittest .mock import patch
35
46from django .core .cache import cache
57from django .test import TestCase
8+ from python_freeipa import exceptions
69
10+ from core .freeipa .client import clear_freeipa_service_client_cache
711from core .freeipa .user import FreeIPAUser
812from core .freeipa .utils import _user_cache_key
913from core .freeipa_directory import search_freeipa_users
@@ -46,6 +50,73 @@ def user_show(self, username: str, *args: object, **kwargs: object) -> dict[str,
4650 return {"result" : self .user_show_results [username ]}
4751
4852
53+ class _LookupExecutionTracker :
54+ def __init__ (
55+ self ,
56+ * ,
57+ transient_unauthorized_by_username : dict [str , int ] | None = None ,
58+ fatal_usernames : set [str ] | None = None ,
59+ ) -> None :
60+ self ._lock = threading .Lock ()
61+ self .client_count = 0
62+ self .active_calls = 0
63+ self .max_active_calls = 0
64+ self .lookup_counts_by_username : dict [str , int ] = {}
65+ self .transient_unauthorized_by_username = dict (transient_unauthorized_by_username or {})
66+ self .fatal_usernames = set (fatal_usernames or set ())
67+
68+ def register_client (self ) -> int :
69+ with self ._lock :
70+ self .client_count += 1
71+ return self .client_count
72+
73+ def begin_lookup (self , username : str ) -> None :
74+ with self ._lock :
75+ self .active_calls += 1
76+ if self .active_calls > self .max_active_calls :
77+ self .max_active_calls = self .active_calls
78+ self .lookup_counts_by_username [username ] = self .lookup_counts_by_username .get (username , 0 ) + 1
79+
80+ def end_lookup (self ) -> None :
81+ with self ._lock :
82+ self .active_calls -= 1
83+
84+ def consume_transient_unauthorized (self , username : str ) -> bool :
85+ with self ._lock :
86+ remaining = self .transient_unauthorized_by_username .get (username , 0 )
87+ if remaining <= 0 :
88+ return False
89+ self .transient_unauthorized_by_username [username ] = remaining - 1
90+ return True
91+
92+
93+ class _TrackingShapeLookupClient (_ShapeLookupClient ):
94+ def __init__ (
95+ self ,
96+ * ,
97+ tracker : _LookupExecutionTracker ,
98+ username_results : dict [str , dict [str , object ]],
99+ pause_seconds : float = 0.0 ,
100+ ) -> None :
101+ super ().__init__ (username_results = username_results )
102+ self .tracker = tracker
103+ self .pause_seconds = pause_seconds
104+
105+ def user_find (self , * args : object , ** kwargs : object ) -> dict [str , object ]:
106+ username = str (kwargs .get ("o_uid" ) or "" ).strip ().lower ()
107+ self .tracker .begin_lookup (username )
108+ try :
109+ if self .pause_seconds :
110+ time .sleep (self .pause_seconds )
111+ if username in self .tracker .fatal_usernames :
112+ raise RuntimeError (f"boom for { username } " )
113+ if self .tracker .consume_transient_unauthorized (username ):
114+ raise exceptions .Unauthorized ()
115+ return super ().user_find (* args , ** kwargs )
116+ finally :
117+ self .tracker .end_lookup ()
118+
119+
49120def _lightweight_row (username : str , * , full_name : str | None = None ) -> dict [str , object ]:
50121 display_name = full_name or username
51122 return {
@@ -138,7 +209,7 @@ def test_find_lightweight_by_usernames_deduplicates_and_uses_exact_shape(self) -
138209 self .assertEqual (set (users_by_username ), {"alice" , "bob" })
139210 self .assertEqual (users_by_username ["alice" ].full_name , "Alice Example" )
140211 self .assertEqual (users_by_username ["bob" ].full_name , "Bob Example" )
141- self .assertEqual (
212+ self .assertCountEqual (
142213 client .user_find_calls ,
143214 [
144215 {
@@ -158,6 +229,110 @@ def test_find_lightweight_by_usernames_deduplicates_and_uses_exact_shape(self) -
158229 ],
159230 )
160231
232+ def test_find_lightweight_by_usernames_uses_serial_fast_path_for_small_inputs (self ) -> None :
233+ tracker = _LookupExecutionTracker ()
234+
235+ def build_client () -> _TrackingShapeLookupClient :
236+ tracker .register_client ()
237+ return _TrackingShapeLookupClient (
238+ tracker = tracker ,
239+ username_results = {
240+ "alice" : {"count" : 1 , "result" : [_lightweight_row ("alice" , full_name = "Alice Example" )]},
241+ "bob" : {"count" : 1 , "result" : [_lightweight_row ("bob" , full_name = "Bob Example" )]},
242+ },
243+ pause_seconds = 0.01 ,
244+ )
245+
246+ with (
247+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD" , 2 , create = True ),
248+ patch ("core.freeipa.user.FreeIPAUser.get_client" , side_effect = build_client ),
249+ ):
250+ users_by_username = FreeIPAUser .find_lightweight_by_usernames (["alice" , "bob" ])
251+
252+ self .assertEqual (set (users_by_username ), {"alice" , "bob" })
253+ self .assertEqual (tracker .client_count , 1 )
254+ self .assertEqual (tracker .max_active_calls , 1 )
255+
256+ def test_find_lightweight_by_usernames_bounds_service_client_fanout (self ) -> None :
257+ tracker = _LookupExecutionTracker ()
258+ usernames = [f"user{ index } " for index in range (8 )]
259+ max_workers = 3
260+ username_results = {
261+ username : {"count" : 1 , "result" : [_lightweight_row (username , full_name = f"{ username } Example" )]}
262+ for username in usernames
263+ }
264+
265+ def build_client (* _args : object , ** _kwargs : object ) -> _TrackingShapeLookupClient :
266+ tracker .register_client ()
267+ return _TrackingShapeLookupClient (
268+ tracker = tracker ,
269+ username_results = username_results ,
270+ pause_seconds = 0.02 ,
271+ )
272+
273+ with (
274+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD" , 2 , create = True ),
275+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_CHUNK_SIZE" , 2 , create = True ),
276+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_MAX_WORKERS" , max_workers , create = True ),
277+ patch ("core.freeipa.client._get_freeipa_client" , side_effect = build_client ),
278+ ):
279+ clear_freeipa_service_client_cache ()
280+ users_by_username = FreeIPAUser .find_lightweight_by_usernames (usernames )
281+ clear_freeipa_service_client_cache ()
282+
283+ self .assertEqual (set (users_by_username ), set (usernames ))
284+ self .assertGreater (tracker .client_count , 1 )
285+ self .assertLessEqual (tracker .client_count , max_workers )
286+ self .assertGreater (tracker .max_active_calls , 1 )
287+
288+ def test_find_lightweight_by_usernames_retries_unauthorized_within_worker_chunk (self ) -> None :
289+ tracker = _LookupExecutionTracker (transient_unauthorized_by_username = {"charlie" : 1 })
290+ usernames = ["alice" , "bob" , "charlie" , "dave" ]
291+ username_results = {
292+ username : {"count" : 1 , "result" : [_lightweight_row (username , full_name = f"{ username } Example" )]}
293+ for username in usernames
294+ }
295+
296+ def build_client () -> _TrackingShapeLookupClient :
297+ tracker .register_client ()
298+ return _TrackingShapeLookupClient (tracker = tracker , username_results = username_results )
299+
300+ with (
301+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD" , 2 , create = True ),
302+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_CHUNK_SIZE" , 2 , create = True ),
303+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_MAX_WORKERS" , 2 , create = True ),
304+ patch ("core.freeipa.user.FreeIPAUser.get_client" , side_effect = build_client ),
305+ ):
306+ users_by_username = FreeIPAUser .find_lightweight_by_usernames (usernames )
307+
308+ self .assertEqual (set (users_by_username ), set (usernames ))
309+ self .assertEqual (tracker .lookup_counts_by_username .get ("alice" ), 1 )
310+ self .assertEqual (tracker .lookup_counts_by_username .get ("bob" ), 1 )
311+ self .assertEqual (tracker .lookup_counts_by_username .get ("charlie" ), 2 )
312+ self .assertEqual (tracker .lookup_counts_by_username .get ("dave" ), 1 )
313+
314+ def test_find_lightweight_by_usernames_returns_empty_dict_when_any_worker_fails (self ) -> None :
315+ tracker = _LookupExecutionTracker (fatal_usernames = {"charlie" })
316+ usernames = ["alice" , "bob" , "charlie" , "dave" ]
317+ username_results = {
318+ username : {"count" : 1 , "result" : [_lightweight_row (username , full_name = f"{ username } Example" )]}
319+ for username in usernames
320+ }
321+
322+ def build_client () -> _TrackingShapeLookupClient :
323+ tracker .register_client ()
324+ return _TrackingShapeLookupClient (tracker = tracker , username_results = username_results )
325+
326+ with (
327+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_SERIAL_THRESHOLD" , 2 , create = True ),
328+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_CHUNK_SIZE" , 2 , create = True ),
329+ patch ("core.freeipa.user._LIGHTWEIGHT_LOOKUP_MAX_WORKERS" , 2 , create = True ),
330+ patch ("core.freeipa.user.FreeIPAUser.get_client" , side_effect = build_client ),
331+ ):
332+ users_by_username = FreeIPAUser .find_lightweight_by_usernames (usernames )
333+
334+ self .assertEqual (users_by_username , {})
335+
161336 def test_find_lightweight_by_usernames_omits_missing_usernames (self ) -> None :
162337 client = _ShapeLookupClient (
163338 username_results = {
0 commit comments