55
66from ldap3 import Connection , SIMPLE
77from ldap3 .core .exceptions import LDAPAttributeError
8+ from ldap3 .utils .conv import escape_filter_chars
89
910from auth import auth_base
1011from model import model_helper
@@ -38,19 +39,21 @@ def _resolve_base_dn(full_username):
3839 return ''
3940
4041
41- def _search (dn , search_filter , attributes , connection ):
42- success = connection .search (dn , search_filter , attributes = attributes )
42+ def _search (dn , search_request , attributes , connection ):
43+ search_string = search_request .as_search_string ()
44+
45+ success = connection .search (dn , search_string , attributes = attributes )
4346 if not success :
4447 if connection .last_error :
4548 LOGGER .warning ('ldap search failed: ' + connection .last_error
46- + '. dn:' + dn + ', filter: ' + search_filter )
49+ + '. dn:' + dn + ', filter: ' + search_string )
4750 return None
4851
4952 return connection .entries
5053
5154
52- def _load_multiple_entries_values (dn , search_filter , attribute_name , connection ):
53- entries = _search (dn , search_filter , [attribute_name ], connection )
55+ def _load_multiple_entries_values (dn , search_request , attribute_name , connection ):
56+ entries = _search (dn , search_request , [attribute_name ], connection )
5457 if entries is None :
5558 return []
5659
@@ -174,12 +177,13 @@ def _fetch_user_groups(self, user_dn, user_uid, connection):
174177
175178 result = set ()
176179
177- result .update (_load_multiple_entries_values (base_dn , '(member=%s)' % user_dn , 'cn' , connection ))
180+ result .update (
181+ _load_multiple_entries_values (base_dn , SearchRequest ('(member=%s)' , user_dn ), 'cn' , connection ))
178182
179183 if user_uid :
180184 result .update (_load_multiple_entries_values (
181185 base_dn ,
182- '(&(objectClass=posixGroup)(memberUid=%s))' % user_uid ,
186+ SearchRequest ( '(&(objectClass=posixGroup)(memberUid=%s))' , user_uid ) ,
183187 'cn' ,
184188 connection ))
185189
@@ -191,23 +195,23 @@ def _get_user_ids(self, full_username, connection):
191195 username_lower = full_username .lower ()
192196 if ',dc=' in username_lower :
193197 base_dn = username_lower
194- search_filter = '(objectClass=*)'
198+ search_request = SearchRequest ( '(objectClass=*)' )
195199 elif '@' in full_username :
196- search_filter = '(userPrincipalName=%s)' % full_username
200+ search_request = SearchRequest ( '(userPrincipalName=%s)' , full_username )
197201 elif '\\ ' in full_username :
198202 username_index = full_username .rfind ('\\ ' ) + 1
199203 username = full_username [username_index :]
200- search_filter = '(sAMAccountName=%s)' % username
204+ search_request = SearchRequest ( '(sAMAccountName=%s)' , username )
201205 else :
202206 LOGGER .warning ('Unsupported username pattern for ' + full_username )
203207 return full_username , None
204208
205- entries = _search (base_dn , search_filter , ['uid' ], connection )
209+ entries = _search (base_dn , search_request , ['uid' ], connection )
206210 if not entries :
207211 return full_username , None
208212
209213 if len (entries ) > 1 :
210- LOGGER .warning ('More than one user found by filter: ' + search_filter )
214+ LOGGER .warning ('More than one user found by filter: ' + str ( search_request ) )
211215 return full_username , None
212216
213217 entry = entries [0 ]
@@ -225,3 +229,15 @@ def _set_user_groups(self, user, groups):
225229
226230 new_groups_content = json .dumps (self ._user_groups , indent = 2 )
227231 file_utils .write_file (self ._groups_file , new_groups_content )
232+
233+
234+ class SearchRequest :
235+ def __init__ (self , template , * variables ) -> None :
236+ escaped_vars = [escape_filter_chars (var ) for var in variables ]
237+ self .search_string = template % tuple (escaped_vars )
238+
239+ def as_search_string (self ):
240+ return self .search_string
241+
242+ def __str__ (self ) -> str :
243+ return self .as_search_string ()
0 commit comments