From 5326b2ea300a20d2a41a4c682495e8b6d6b72c4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 10:32:41 +0200 Subject: [PATCH 1/9] LDAP refactoring WIP --- seacatauth/credentials/providers/ldap.py | 504 +++++++++++------------ 1 file changed, 239 insertions(+), 265 deletions(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index 7cadbfa6..f2349340 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -2,6 +2,7 @@ import base64 import datetime import contextlib +import typing from typing import Optional @@ -30,10 +31,6 @@ } -class LDAPObject(ldap.ldapobject.LDAPObject, ldap.resiter.ResultProcessor): - pass - - class LDAPCredentialsService(asab.Service): def __init__(self, app, service_name="seacatauth.credentials.ldap"): @@ -85,145 +82,114 @@ def __init__(self, provider_id, config_section_name, proactor_svc): # synchronous library (python-ldap) to be used from asynchronous code self.ProactorService = proactor_svc - attr = set(self.Config["attributes"].split(" ")) - attr.add("createTimestamp") - attr.add("modifyTimestamp") - attr.add("cn") - attr.add(self.Config["attrusername"]) - self.AttrList = list(attr) + self.LdapUri = self.Config["uri"] + self.Base = self.Config["base"] + self.AttrList = _prepare_attributes(self.Config) # Fields to filter by when locating a user - self._locate_filter_fields = ["cn", "mail", "mobile"] + self.IdentFields = ["mail", "mobile"] # If attrusername field is not empty, locate by it too if len(self.Config["attrusername"]) > 0: - self._locate_filter_fields.append(self.Config["attrusername"]) - - - async def get_login_descriptors(self, credentials_id): - return [{ - "id": "default", - "label": "Use recommended login.", - "factors": [{ - "id": "password", - "type": "password" - }], - }] + self.IdentFields.append(self.Config["attrusername"]) - @contextlib.contextmanager - def _ldap_client(self): - ldap_client = LDAPObject(self.Config["uri"]) - ldap_client.protocol_version = ldap.VERSION3 - ldap_client.set_option(ldap.OPT_REFERRALS, 0) - - network_timeout = int(self.Config.get("network_timeout")) - ldap_client.set_option(ldap.OPT_NETWORK_TIMEOUT, network_timeout) + async def get(self, credentials_id, include=None) -> Optional[dict]: + if not credentials_id.startswith(self.Prefix): + raise KeyError("Credentials {!r} not found".format(credentials_id)) + return await self.ProactorService.execute(self._get_worker, credentials_id, include) - # Enable TLS - if self.Config["uri"].startswith("ldaps"): - self._enable_tls(ldap_client) - ldap_client.simple_bind_s(self.Config["username"], self.Config["password"]) + async def search(self, filter: dict = None, **kwargs) -> list: + # TODO: Implement pagination + filterstr = self._build_search_filter(filter) + return await self.ProactorService.execute(self._search_worker, filterstr) - try: - yield ldap_client - finally: - ldap_client.unbind_s() + async def count(self, filtr=None) -> int: + filterstr = self._build_search_filter(filtr) + return await self.ProactorService.execute(self._count_worker, filterstr) - def _enable_tls(self, ldap_client): - tls_cafile = self.Config["tls_cafile"] - # Add certificate authority - if len(tls_cafile) > 0: - ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, tls_cafile) + async def iterate(self, offset: int = 0, limit: int = -1, filtr: str = None): + filterstr = self._build_search_filter(filtr) + results = await self.ProactorService.execute(self._search_worker, filterstr) + for i in results[offset : (None if limit == -1 else limit + offset)]: + yield i - # Set cert policy - if self.Config["tls_require_cert"] == "never": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) - elif self.Config["tls_require_cert"] == "demand": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - elif self.Config["tls_require_cert"] == "allow": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) - elif self.Config["tls_require_cert"] == "hard": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_HARD) - else: - L.error("Invalid 'tls_require_cert' value: {!r}. Defaulting to 'demand'.".format( - self.Config["tls_require_cert"] - )) - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - # Misc TLS options - tls_protocol_min = self.Config["tls_protocol_min"] - if tls_protocol_min != "": - if tls_protocol_min not in _TLS_VERSION: - raise ValueError("'tls_protocol_min' must be one of {} or empty.".format(list(_TLS_VERSION))) - ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, _TLS_VERSION[tls_protocol_min]) + async def locate(self, ident: str, ident_fields: dict = None, login_dict: dict = None) -> str: + return await self.ProactorService.execute(self._locate_worker, ident, ident_fields) - tls_protocol_max = self.Config["tls_protocol_max"] - if tls_protocol_max != "": - if tls_protocol_max not in _TLS_VERSION: - raise ValueError("'tls_protocol_max' must be one of {} or empty.".format(list(_TLS_VERSION))) - ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MAX, _TLS_VERSION[tls_protocol_max]) - if self.Config["tls_cipher_suite"] != "": - ldap_client.set_option(ldap.OPT_X_TLS_CIPHER_SUITE, self.Config["tls_cipher_suite"]) + async def authenticate(self, credentials_id: str, credentials: dict) -> bool: + return await self.ProactorService.execute(self._authenticate_worker, credentials_id, credentials) - # NEWCTX needs to be the last option, because it applies all the prepared options to the new context - ldap_client.set_option(ldap.OPT_X_TLS_NEWCTX, 0) - def _get_worker(self, prefix, credentials_id, include=None) -> Optional[dict]: + async def get_login_descriptors(self, credentials_id): + # Only login with password is supported + return [{ + "id": "default", + "label": "Use recommended login.", + "factors": [{ + "id": "password", + "type": "password" + }], + }] - # TODO: Validate credetials_id with regex - cn = base64.urlsafe_b64decode(credentials_id[len(prefix):]).decode("utf-8") + def _get_worker(self, credentials_id, include=None) -> Optional[dict]: + cn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") with self._ldap_client() as lc: try: - sr = lc.search_s( + results = lc.search_s( cn, ldap.SCOPE_BASE, filterstr=self.Config["filter"], attrlist=self.AttrList, ) except ldap.NO_SUCH_OBJECT as e: - L.error(e) - sr = [] + raise KeyError("Credentials {!r} not found".format(credentials_id)) from e - if len(sr) == 0: + if len(results) > 1: + L.exception("Multiple credentials matched ID.", struct_data={"cid": credentials_id}) raise KeyError("Credentials {!r} not found".format(credentials_id)) - assert len(sr) == 1 - dn, entry = sr[0] + dn, entry = results[0] + return self._normalize_credentials(dn, entry) - return _normalize_entry( - prefix, - self.Type, - self.ProviderID, - dn, - entry, - self.Config["attrusername"] - ) + def _search_worker(self, filterstr): + # TODO: sorting + result = [] - async def get(self, credentials_id, include=None) -> Optional[dict]: - prefix = "{}:{}:".format(self.Type, self.ProviderID) - if not credentials_id.startswith(prefix): - raise KeyError("Credentials {!r} not found".format(credentials_id)) + with self._ldap_client() as ldap_client: + msgid = ldap_client.search( + self.Base, + ldap.SCOPE_SUBTREE, + filterstr=filterstr, + attrlist=self.AttrList, + ) + result_iter = ldap_client.allresults(msgid) + + for res_type, res_data, res_msgid, res_controls in result_iter: + for dn, entry in res_data: + if dn is not None: + result.append(self._normalize_credentials(dn, entry)) - return await self.ProactorService.execute(self._get_worker, prefix, credentials_id, include) + return result def _count_worker(self, filterstr): count = 0 - with self._ldap_client() as lc: - msgid = lc.search( + with self._ldap_client() as ldap_client: + msgid = ldap_client.search( self.Config["base"], ldap.SCOPE_SUBTREE, filterstr=filterstr, attrsonly=1, # If attrsonly is non-zero attrlist=["cn", "mail", "mobile"], # For counting, we need only absolutely minimum set of attributes ) - result_iter = lc.allresults(msgid) + result_iter = ldap_client.allresults(msgid) for res_type, res_data, res_msgid, res_controls in result_iter: for dn, entry in res_data: @@ -235,211 +201,210 @@ def _count_worker(self, filterstr): return count - async def count(self, filtr=None) -> int: - filterstr = self._build_search_filter(filtr) - return await self.ProactorService.execute(self._count_worker, filterstr) - - - def _search_worker(self, filterstr): - - # TODO: sorting - prefix = "{}:{}:".format(self.Type, self.ProviderID) - result = [] - - with self._ldap_client() as lc: - msgid = lc.search( - self.Config["base"], - ldap.SCOPE_SUBTREE, - filterstr=filterstr, - attrlist=self.AttrList, - ) - result_iter = lc.allresults(msgid) - - for res_type, res_data, res_msgid, res_controls in result_iter: - for dn, entry in res_data: - if dn is not None: - result.append(_normalize_entry( - prefix, - self.Type, - self.ProviderID, - dn, - entry, - self.Config["attrusername"] - )) - - return result - - - async def search(self, filter: dict = None, **kwargs) -> list: - # TODO: Implement filtering and pagination - if filter is not None: - return [] - filterstr = self.Config["filter"] - return await self.ProactorService.execute(self._search_worker, filterstr) - - - async def iterate(self, offset: int = 0, limit: int = -1, filtr: str = None): - filterstr = self._build_search_filter(filtr) - arr = await self.ProactorService.execute(self._search_worker, filterstr) - for i in arr[offset:None if limit == -1 else limit + offset]: - yield i - - def _build_search_filter(self, filtr=None): - if not filtr: - filterstr = self.Config["filter"] - else: - # The query filter is the intersection of the filter from config - # and the filter defined by the search request - # The username must START WITH the given filter string - filter_template = "(&{}({}=*%s*))".format(self.Config["filter"], self.Config["attrusername"]) - assertion_values = ["{}".format(filtr.lower())] - filterstr = ldap.filter.filter_format( - filter_template=filter_template, - assertion_values=assertion_values - ) - return filterstr - - def _locate_worker(self, ident: str): - with self._ldap_client() as lc: - - # Build the filter template - # Example: (|(cn=%s)(mail=%s)(mobile=%s)(sAMAccountName=%s)) - filter_template = "(|{})".format( - "".join("({}=%s)".format(field) for field in self._locate_filter_fields) - ) - assertion_values = tuple( - ident for _ in self._locate_filter_fields - ) - - msgid = lc.search( + def _locate_worker(self, ident: str, ident_fields: typing.Optional[typing.Mapping[str, str]]): + # TODO: Implement configurable ident_fields support + with self._ldap_client() as ldap_client: + msgid = ldap_client.search( self.Config["base"], ldap.SCOPE_SUBTREE, filterstr=ldap.filter.filter_format( - filter_template=filter_template, - assertion_values=assertion_values + # Build the filter template + # Example: (|(cn=%s)(mail=%s)(mobile=%s)(sAMAccountName=%s)) + filter_template="(|{})".format( + "".join("({}=%s)".format(field) for field in self.IdentFields)), + assertion_values=tuple(ident for _ in self.IdentFields) ), - attrlist=["cn"] + attrlist=["cn"], ) - result_iter = lc.allresults(msgid) + result_iter = ldap_client.allresults(msgid) for res_type, res_data, res_msgid, res_controls in result_iter: for dn, entry in res_data: if dn is not None: - return "{}:{}:{}".format( - self.Type, - self.ProviderID, - base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii"), - ) + return self._format_credentials_id(dn) return None - async def locate(self, ident: str, ident_fields: dict = None, login_dict: dict = None) -> str: - # TODO: Implement ident_fields support - """ - Locate search for the exact match of provided ident and the username in the htpasswd file - """ - return await self.ProactorService.execute(self._locate_worker, ident) - - def _authenticate_worker(self, credentials_id: str, credentials: dict) -> bool: - prefix = "{}:{}:".format(self.Type, self.ProviderID) - password = credentials.get("password") - dn = base64.urlsafe_b64decode(credentials_id[len(prefix):]).decode("utf-8") + dn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") - lc = LDAPObject(self.Config["uri"]) - lc.protocol_version = ldap.VERSION3 - lc.set_option(ldap.OPT_REFERRALS, 0) + ldap_client = _LDAPObject(self.LdapUri) + ldap_client.protocol_version = ldap.VERSION3 + ldap_client.set_option(ldap.OPT_REFERRALS, 0) # Enable TLS - if self.Config["uri"].startswith("ldaps"): - self._enable_tls(lc) + if self.LdapUri.startswith("ldaps"): + self._enable_tls(ldap_client) try: - lc.simple_bind_s(dn, password) + ldap_client.simple_bind_s(dn, password) except ldap.INVALID_CREDENTIALS: L.log(asab.LOG_NOTICE, "Authentication failed: Invalid LDAP credentials.", struct_data={ "cid": credentials_id, "dn": dn}) return False - lc.unbind_s() + ldap_client.unbind_s() return True - async def authenticate(self, credentials_id: str, credentials: dict) -> bool: - return await self.ProactorService.execute(self._authenticate_worker, credentials_id, credentials) + def _normalize_credentials(self, dn: str, search_result: typing.Mapping): + ret = { + "_id": self._format_credentials_id(dn), + "_type": self.Type, + "_provider_id": self.ProviderID, + } + + decoded_result = {"dn": dn} + for k, v in search_result.items(): + if k =="userPassword": + continue + if isinstance(v, list): + if len(v) == 0: + continue + elif len(v) == 1: + decoded_result[k] = v[0].decode("utf-8") + else: + decoded_result[k] = [i.decode("utf-8") for i in v] + + v = decoded_result.pop(self.Config["attrusername"], None) + if v is not None: + ret["username"] = v + else: + # This is fallback, since we need a username on various places + ret["username"] = dn + v = decoded_result.pop("cn", None) + if v is not None: + ret["full_name"] = v -def _normalize_entry(prefix, ptype, provider_id, dn, entry, attrusername: str = "cn"): - ret = { - "_id": prefix + base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii"), - "_type": ptype, - "_provider_id": provider_id, - } + v = decoded_result.pop("mail", None) + if v is not None: + ret["email"] = v - ldap_obj = { - "dn": dn, - } - for k, v in entry.items(): - if k in frozenset(["userPassword"]): - continue - if isinstance(v, list): - if len(v) == 1: - v = v[0].decode("utf-8") - else: - v = [i.decode("utf-8") for i in v] - ldap_obj[k] = v - - v = ldap_obj.pop(attrusername, None) - if v is not None: - ret["username"] = v - else: - # This is fallback, since we need a username on various places - ret["username"] = dn - - v = ldap_obj.pop("cn", None) - if v is not None: - ret["full_name"] = v - - v = ldap_obj.pop("mail", None) - if v is not None: - ret["email"] = v - - v = ldap_obj.pop("mobile", None) - if v is not None: - ret["phone"] = v - - v = ldap_obj.pop("userAccountControl", None) - if v is not None: - # userAccountControl is an array of binary flags returned as a decimal integer - # byte #1 is ACCOUNTDISABLE which corresponds to "suspended" status - # https://learn.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties - try: - ret["suspended"] = int(v) & 2 == 2 - except ValueError: - pass - - v = ldap_obj.pop("createTimestamp", None) - if v is not None: - ret["_c"] = _parse_timestamp(v) - else: - v = ldap_obj.pop("createTimeStamp", None) + v = decoded_result.pop("mobile", None) + if v is not None: + ret["phone"] = v + + v = decoded_result.pop("userAccountControl", None) + if v is not None: + # userAccountControl is an array of binary flags returned as a decimal integer + # byte #1 is ACCOUNTDISABLE which corresponds to "suspended" status + # https://learn.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties + try: + ret["suspended"] = int(v) & 2 == 2 + except ValueError: + pass + + v = decoded_result.pop("createTimestamp", None) if v is not None: ret["_c"] = _parse_timestamp(v) + else: + v = decoded_result.pop("createTimeStamp", None) + if v is not None: + ret["_c"] = _parse_timestamp(v) - v = ldap_obj.pop("modifyTimestamp", None) - if v is not None: - ret["_m"] = _parse_timestamp(v) - else: - v = ldap_obj.pop("modifyTimeStamp", None) + v = decoded_result.pop("modifyTimestamp", None) if v is not None: ret["_m"] = _parse_timestamp(v) + else: + v = decoded_result.pop("modifyTimeStamp", None) + if v is not None: + ret["_m"] = _parse_timestamp(v) + + if len(decoded_result) > 0: + ret["_ldap"] = decoded_result + + return ret + + + @contextlib.contextmanager + def _ldap_client(self): + ldap_client = _LDAPObject(self.LdapUri) + ldap_client.protocol_version = ldap.VERSION3 + ldap_client.set_option(ldap.OPT_REFERRALS, 0) + + network_timeout = self.Config.getint("network_timeout") + ldap_client.set_option(ldap.OPT_NETWORK_TIMEOUT, network_timeout) + + # Enable TLS + if self.LdapUri.startswith("ldaps"): + self._enable_tls(ldap_client) + + ldap_client.simple_bind_s(self.Config["username"], self.Config["password"]) - if len(ldap_obj) > 0: - ret["_ldap"] = ldap_obj + try: + yield ldap_client - return ret + finally: + ldap_client.unbind_s() + + + def _enable_tls(self, ldap_client): + tls_cafile = self.Config["tls_cafile"] + + # Add certificate authority + if len(tls_cafile) > 0: + ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, tls_cafile) + + # Set cert policy + if self.Config["tls_require_cert"] == "never": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) + elif self.Config["tls_require_cert"] == "demand": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + elif self.Config["tls_require_cert"] == "allow": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) + elif self.Config["tls_require_cert"] == "hard": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_HARD) + else: + L.error("Invalid 'tls_require_cert' value: {!r}. Defaulting to 'demand'.".format( + self.Config["tls_require_cert"] + )) + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + + # Misc TLS options + tls_protocol_min = self.Config["tls_protocol_min"] + if tls_protocol_min != "": + if tls_protocol_min not in _TLS_VERSION: + raise ValueError("'tls_protocol_min' must be one of {} or empty.".format(list(_TLS_VERSION))) + ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, _TLS_VERSION[tls_protocol_min]) + + tls_protocol_max = self.Config["tls_protocol_max"] + if tls_protocol_max != "": + if tls_protocol_max not in _TLS_VERSION: + raise ValueError("'tls_protocol_max' must be one of {} or empty.".format(list(_TLS_VERSION))) + ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MAX, _TLS_VERSION[tls_protocol_max]) + + if self.Config["tls_cipher_suite"] != "": + ldap_client.set_option(ldap.OPT_X_TLS_CIPHER_SUITE, self.Config["tls_cipher_suite"]) + + # NEWCTX needs to be the last option, because it applies all the prepared options to the new context + ldap_client.set_option(ldap.OPT_X_TLS_NEWCTX, 0) + + + def _format_credentials_id(self, dn): + return self.Prefix + base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii") + + + def _build_search_filter(self, filtr: typing.Optional[str] = None): + if not filtr: + filterstr = self.Config["filter"] + else: + # The query filter is the intersection of the filter from config + # and the filter defined by the search request + # The username must START WITH the given filter string + filter_template = "(&{}({}=*%s*))".format(self.Config["filter"], self.Config["attrusername"]) + assertion_values = ["{}".format(filtr.lower())] + filterstr = ldap.filter.filter_format( + filter_template=filter_template, + assertion_values=assertion_values + ) + return filterstr + + +class _LDAPObject(ldap.ldapobject.LDAPObject, ldap.resiter.ResultProcessor): + pass def _parse_timestamp(ts: str) -> datetime.datetime: @@ -449,3 +414,12 @@ def _parse_timestamp(ts: str) -> datetime.datetime: pass return datetime.datetime.strptime(ts, r"%Y%m%d%H%M%S.%fZ") + + +def _prepare_attributes(config: typing.Mapping): + attr = set(config["attributes"].split(" ")) + attr.add("createTimestamp") + attr.add("modifyTimestamp") + attr.add("cn") + attr.add(config["attrusername"]) + return list(attr) From 137a08bfbdded670ac61a7fed22983a88ea28d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 10:37:38 +0200 Subject: [PATCH 2/9] LDAP refactoring WIP --- seacatauth/credentials/providers/ldap.py | 53 ++++++++++++------------ 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index f2349340..05d15507 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -84,6 +84,7 @@ def __init__(self, provider_id, config_section_name, proactor_svc): self.LdapUri = self.Config["uri"] self.Base = self.Config["base"] + self.Filter = self.Config["filter"] self.AttrList = _prepare_attributes(self.Config) # Fields to filter by when locating a user @@ -144,7 +145,7 @@ def _get_worker(self, credentials_id, include=None) -> Optional[dict]: results = lc.search_s( cn, ldap.SCOPE_BASE, - filterstr=self.Config["filter"], + filterstr=self.Filter, attrlist=self.AttrList, ) except ldap.NO_SUCH_OBJECT as e: @@ -160,7 +161,7 @@ def _get_worker(self, credentials_id, include=None) -> Optional[dict]: def _search_worker(self, filterstr): # TODO: sorting - result = [] + results = [] with self._ldap_client() as ldap_client: msgid = ldap_client.search( @@ -174,16 +175,16 @@ def _search_worker(self, filterstr): for res_type, res_data, res_msgid, res_controls in result_iter: for dn, entry in res_data: if dn is not None: - result.append(self._normalize_credentials(dn, entry)) + results.append(self._normalize_credentials(dn, entry)) - return result + return results def _count_worker(self, filterstr): count = 0 with self._ldap_client() as ldap_client: msgid = ldap_client.search( - self.Config["base"], + self.Base, ldap.SCOPE_SUBTREE, filterstr=filterstr, attrsonly=1, # If attrsonly is non-zero @@ -193,9 +194,7 @@ def _count_worker(self, filterstr): for res_type, res_data, res_msgid, res_controls in result_iter: for dn, entry in res_data: - if dn is None: - continue - else: + if dn is not None: count += 1 return count @@ -205,7 +204,7 @@ def _locate_worker(self, ident: str, ident_fields: typing.Optional[typing.Mappin # TODO: Implement configurable ident_fields support with self._ldap_client() as ldap_client: msgid = ldap_client.search( - self.Config["base"], + self.Base, ldap.SCOPE_SUBTREE, filterstr=ldap.filter.filter_format( # Build the filter template @@ -249,45 +248,45 @@ def _authenticate_worker(self, credentials_id: str, credentials: dict) -> bool: return True - def _normalize_credentials(self, dn: str, search_result: typing.Mapping): + def _normalize_credentials(self, dn: str, search_record: typing.Mapping): ret = { "_id": self._format_credentials_id(dn), "_type": self.Type, "_provider_id": self.ProviderID, } - decoded_result = {"dn": dn} - for k, v in search_result.items(): + decoded_record = {"dn": dn} + for k, v in search_record.items(): if k =="userPassword": continue if isinstance(v, list): if len(v) == 0: continue elif len(v) == 1: - decoded_result[k] = v[0].decode("utf-8") + decoded_record[k] = v[0].decode("utf-8") else: - decoded_result[k] = [i.decode("utf-8") for i in v] + decoded_record[k] = [i.decode("utf-8") for i in v] - v = decoded_result.pop(self.Config["attrusername"], None) + v = decoded_record.pop(self.Config["attrusername"], None) if v is not None: ret["username"] = v else: # This is fallback, since we need a username on various places ret["username"] = dn - v = decoded_result.pop("cn", None) + v = decoded_record.pop("cn", None) if v is not None: ret["full_name"] = v - v = decoded_result.pop("mail", None) + v = decoded_record.pop("mail", None) if v is not None: ret["email"] = v - v = decoded_result.pop("mobile", None) + v = decoded_record.pop("mobile", None) if v is not None: ret["phone"] = v - v = decoded_result.pop("userAccountControl", None) + v = decoded_record.pop("userAccountControl", None) if v is not None: # userAccountControl is an array of binary flags returned as a decimal integer # byte #1 is ACCOUNTDISABLE which corresponds to "suspended" status @@ -297,24 +296,24 @@ def _normalize_credentials(self, dn: str, search_result: typing.Mapping): except ValueError: pass - v = decoded_result.pop("createTimestamp", None) + v = decoded_record.pop("createTimestamp", None) if v is not None: ret["_c"] = _parse_timestamp(v) else: - v = decoded_result.pop("createTimeStamp", None) + v = decoded_record.pop("createTimeStamp", None) if v is not None: ret["_c"] = _parse_timestamp(v) - v = decoded_result.pop("modifyTimestamp", None) + v = decoded_record.pop("modifyTimestamp", None) if v is not None: ret["_m"] = _parse_timestamp(v) else: - v = decoded_result.pop("modifyTimeStamp", None) + v = decoded_record.pop("modifyTimeStamp", None) if v is not None: ret["_m"] = _parse_timestamp(v) - if len(decoded_result) > 0: - ret["_ldap"] = decoded_result + if len(decoded_record) > 0: + ret["_ldap"] = decoded_record return ret @@ -389,12 +388,12 @@ def _format_credentials_id(self, dn): def _build_search_filter(self, filtr: typing.Optional[str] = None): if not filtr: - filterstr = self.Config["filter"] + filterstr = self.Filter else: # The query filter is the intersection of the filter from config # and the filter defined by the search request # The username must START WITH the given filter string - filter_template = "(&{}({}=*%s*))".format(self.Config["filter"], self.Config["attrusername"]) + filter_template = "(&{}({}=*%s*))".format(self.Filter, self.Config["attrusername"]) assertion_values = ["{}".format(filtr.lower())] filterstr = ldap.filter.filter_format( filter_template=filter_template, From 96c396ce0e2757068643678d7b249c6bc593e8ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 12:34:35 +0200 Subject: [PATCH 3/9] more refactoring --- seacatauth/credentials/providers/ldap.py | 174 +++++++++++------------ 1 file changed, 86 insertions(+), 88 deletions(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index 05d15507..bf37d7b5 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -13,6 +13,7 @@ import asab import asab.proactor +import asab.config from .abc import CredentialsProviderABC @@ -94,13 +95,17 @@ def __init__(self, provider_id, config_section_name, proactor_svc): self.IdentFields.append(self.Config["attrusername"]) - async def get(self, credentials_id, include=None) -> Optional[dict]: + async def get(self, credentials_id: str, include: typing.Optional[typing.Iterable[str]] = None) -> Optional[dict]: if not credentials_id.startswith(self.Prefix): raise KeyError("Credentials {!r} not found".format(credentials_id)) - return await self.ProactorService.execute(self._get_worker, credentials_id, include) + cn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") + try: + return await self.ProactorService.execute(self._get_worker, cn) + except KeyError as e: + raise KeyError("Credentials not found: {!r}".format(credentials_id)) from e - async def search(self, filter: dict = None, **kwargs) -> list: + async def search(self, filter: dict = None, sort: dict = None, page: int = 0, limit: int = 0, **kwargs) -> list: # TODO: Implement pagination filterstr = self._build_search_filter(filter) return await self.ProactorService.execute(self._search_worker, filterstr) @@ -123,10 +128,12 @@ async def locate(self, ident: str, ident_fields: dict = None, login_dict: dict = async def authenticate(self, credentials_id: str, credentials: dict) -> bool: - return await self.ProactorService.execute(self._authenticate_worker, credentials_id, credentials) + dn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") + password = credentials.get("password") + return await self.ProactorService.execute(self._authenticate_worker, dn, password) - async def get_login_descriptors(self, credentials_id): + async def get_login_descriptors(self, credentials_id: str) -> typing.List[typing.Dict]: # Only login with password is supported return [{ "id": "default", @@ -138,28 +145,46 @@ async def get_login_descriptors(self, credentials_id): }] - def _get_worker(self, credentials_id, include=None) -> Optional[dict]: - cn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") - with self._ldap_client() as lc: + @contextlib.contextmanager + def _ldap_client(self): + ldap_client = _LDAPObject(self.LdapUri) + ldap_client.protocol_version = ldap.VERSION3 + ldap_client.set_option(ldap.OPT_REFERRALS, 0) + + network_timeout = self.Config.getint("network_timeout") + ldap_client.set_option(ldap.OPT_NETWORK_TIMEOUT, network_timeout) + + if self.LdapUri.startswith("ldaps"): + _enable_tls(ldap_client, self.Config) + + ldap_client.simple_bind_s(self.Config["username"], self.Config["password"]) + try: + yield ldap_client + finally: + ldap_client.unbind_s() + + + def _get_worker(self, cn: str) -> Optional[dict]: + with self._ldap_client() as ldap_client: try: - results = lc.search_s( + results = ldap_client.search_s( cn, ldap.SCOPE_BASE, filterstr=self.Filter, attrlist=self.AttrList, ) except ldap.NO_SUCH_OBJECT as e: - raise KeyError("Credentials {!r} not found".format(credentials_id)) from e + raise KeyError("CN matched no LDAP objects.") from e if len(results) > 1: - L.exception("Multiple credentials matched ID.", struct_data={"cid": credentials_id}) - raise KeyError("Credentials {!r} not found".format(credentials_id)) + L.error("CN matched multiple LDAP objects.", struct_data={"CN": cn}) + raise KeyError("CN matched multiple LDAP objects.") dn, entry = results[0] return self._normalize_credentials(dn, entry) - def _search_worker(self, filterstr): + def _search_worker(self, filterstr: str): # TODO: sorting results = [] @@ -180,7 +205,7 @@ def _search_worker(self, filterstr): return results - def _count_worker(self, filterstr): + def _count_worker(self, filterstr: str): count = 0 with self._ldap_client() as ldap_client: msgid = ldap_client.search( @@ -200,7 +225,7 @@ def _count_worker(self, filterstr): return count - def _locate_worker(self, ident: str, ident_fields: typing.Optional[typing.Mapping[str, str]]): + def _locate_worker(self, ident: str, ident_fields: typing.Optional[typing.Mapping[str, str]] = None): # TODO: Implement configurable ident_fields support with self._ldap_client() as ldap_client: msgid = ldap_client.search( @@ -224,23 +249,18 @@ def _locate_worker(self, ident: str, ident_fields: typing.Optional[typing.Mappin return None - def _authenticate_worker(self, credentials_id: str, credentials: dict) -> bool: - password = credentials.get("password") - dn = base64.urlsafe_b64decode(credentials_id[len(self.Prefix):]).decode("utf-8") - + def _authenticate_worker(self, dn: str, password: str) -> bool: ldap_client = _LDAPObject(self.LdapUri) ldap_client.protocol_version = ldap.VERSION3 ldap_client.set_option(ldap.OPT_REFERRALS, 0) - # Enable TLS if self.LdapUri.startswith("ldaps"): - self._enable_tls(ldap_client) + _enable_tls(ldap_client, self.Config) try: ldap_client.simple_bind_s(dn, password) except ldap.INVALID_CREDENTIALS: - L.log(asab.LOG_NOTICE, "Authentication failed: Invalid LDAP credentials.", struct_data={ - "cid": credentials_id, "dn": dn}) + L.log(asab.LOG_NOTICE, "Authentication failed: Invalid LDAP credentials.", struct_data={"dn": dn}) return False ldap_client.unbind_s() @@ -318,70 +338,6 @@ def _normalize_credentials(self, dn: str, search_record: typing.Mapping): return ret - @contextlib.contextmanager - def _ldap_client(self): - ldap_client = _LDAPObject(self.LdapUri) - ldap_client.protocol_version = ldap.VERSION3 - ldap_client.set_option(ldap.OPT_REFERRALS, 0) - - network_timeout = self.Config.getint("network_timeout") - ldap_client.set_option(ldap.OPT_NETWORK_TIMEOUT, network_timeout) - - # Enable TLS - if self.LdapUri.startswith("ldaps"): - self._enable_tls(ldap_client) - - ldap_client.simple_bind_s(self.Config["username"], self.Config["password"]) - - try: - yield ldap_client - - finally: - ldap_client.unbind_s() - - - def _enable_tls(self, ldap_client): - tls_cafile = self.Config["tls_cafile"] - - # Add certificate authority - if len(tls_cafile) > 0: - ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, tls_cafile) - - # Set cert policy - if self.Config["tls_require_cert"] == "never": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) - elif self.Config["tls_require_cert"] == "demand": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - elif self.Config["tls_require_cert"] == "allow": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) - elif self.Config["tls_require_cert"] == "hard": - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_HARD) - else: - L.error("Invalid 'tls_require_cert' value: {!r}. Defaulting to 'demand'.".format( - self.Config["tls_require_cert"] - )) - ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - - # Misc TLS options - tls_protocol_min = self.Config["tls_protocol_min"] - if tls_protocol_min != "": - if tls_protocol_min not in _TLS_VERSION: - raise ValueError("'tls_protocol_min' must be one of {} or empty.".format(list(_TLS_VERSION))) - ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, _TLS_VERSION[tls_protocol_min]) - - tls_protocol_max = self.Config["tls_protocol_max"] - if tls_protocol_max != "": - if tls_protocol_max not in _TLS_VERSION: - raise ValueError("'tls_protocol_max' must be one of {} or empty.".format(list(_TLS_VERSION))) - ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MAX, _TLS_VERSION[tls_protocol_max]) - - if self.Config["tls_cipher_suite"] != "": - ldap_client.set_option(ldap.OPT_X_TLS_CIPHER_SUITE, self.Config["tls_cipher_suite"]) - - # NEWCTX needs to be the last option, because it applies all the prepared options to the new context - ldap_client.set_option(ldap.OPT_X_TLS_NEWCTX, 0) - - def _format_credentials_id(self, dn): return self.Prefix + base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii") @@ -393,7 +349,7 @@ def _build_search_filter(self, filtr: typing.Optional[str] = None): # The query filter is the intersection of the filter from config # and the filter defined by the search request # The username must START WITH the given filter string - filter_template = "(&{}({}=*%s*))".format(self.Filter, self.Config["attrusername"]) + filter_template = "(&{}({}=%s*))".format(self.Filter, self.Config["attrusername"]) assertion_values = ["{}".format(filtr.lower())] filterstr = ldap.filter.filter_format( filter_template=filter_template, @@ -422,3 +378,45 @@ def _prepare_attributes(config: typing.Mapping): attr.add("cn") attr.add(config["attrusername"]) return list(attr) + + +def _enable_tls(ldap_client, config: typing.Mapping): + tls_cafile = config["tls_cafile"] + + # Add certificate authority + if len(tls_cafile) > 0: + ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, tls_cafile) + + # Set cert policy + if config["tls_require_cert"] == "never": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) + elif config["tls_require_cert"] == "demand": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + elif config["tls_require_cert"] == "allow": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) + elif config["tls_require_cert"] == "hard": + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_HARD) + else: + L.error("Invalid 'tls_require_cert' value: {!r}. Defaulting to 'demand'.".format( + config["tls_require_cert"] + )) + ldap_client.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + + # Misc TLS options + tls_protocol_min = config["tls_protocol_min"] + if tls_protocol_min != "": + if tls_protocol_min not in _TLS_VERSION: + raise ValueError("'tls_protocol_min' must be one of {} or empty.".format(list(_TLS_VERSION))) + ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, _TLS_VERSION[tls_protocol_min]) + + tls_protocol_max = config["tls_protocol_max"] + if tls_protocol_max != "": + if tls_protocol_max not in _TLS_VERSION: + raise ValueError("'tls_protocol_max' must be one of {} or empty.".format(list(_TLS_VERSION))) + ldap_client.set_option(ldap.OPT_X_TLS_PROTOCOL_MAX, _TLS_VERSION[tls_protocol_max]) + + if config["tls_cipher_suite"] != "": + ldap_client.set_option(ldap.OPT_X_TLS_CIPHER_SUITE, config["tls_cipher_suite"]) + + # NEWCTX needs to be the last option, because it applies all the prepared options to the new context + ldap_client.set_option(ldap.OPT_X_TLS_NEWCTX, 0) From 2665715ca8b48e3127c2d8c02389b6de40afa3dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 17:36:36 +0200 Subject: [PATCH 4/9] more refactoring --- seacatauth/credentials/providers/ldap.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index bf37d7b5..1c531856 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -164,7 +164,7 @@ def _ldap_client(self): ldap_client.unbind_s() - def _get_worker(self, cn: str) -> Optional[dict]: + def _get_worker(self, cn: str) -> Optional[typing.Dict]: with self._ldap_client() as ldap_client: try: results = ldap_client.search_s( @@ -184,10 +184,9 @@ def _get_worker(self, cn: str) -> Optional[dict]: return self._normalize_credentials(dn, entry) - def _search_worker(self, filterstr: str): - # TODO: sorting + def _search_worker(self, filterstr: str) -> typing.List[typing.Dict]: + # TODO: Implement sorting (Note that not all LDAP servers support server-side sorting) results = [] - with self._ldap_client() as ldap_client: msgid = ldap_client.search( self.Base, @@ -205,7 +204,7 @@ def _search_worker(self, filterstr: str): return results - def _count_worker(self, filterstr: str): + def _count_worker(self, filterstr: str) -> int: count = 0 with self._ldap_client() as ldap_client: msgid = ldap_client.search( @@ -225,7 +224,11 @@ def _count_worker(self, filterstr: str): return count - def _locate_worker(self, ident: str, ident_fields: typing.Optional[typing.Mapping[str, str]] = None): + def _locate_worker( + self, + ident: str, + ident_fields: typing.Optional[typing.Mapping[str, str]] = None + ) -> typing.Optional[str]: # TODO: Implement configurable ident_fields support with self._ldap_client() as ldap_client: msgid = ldap_client.search( @@ -268,7 +271,7 @@ def _authenticate_worker(self, dn: str, password: str) -> bool: return True - def _normalize_credentials(self, dn: str, search_record: typing.Mapping): + def _normalize_credentials(self, dn: str, search_record: typing.Mapping) -> typing.Dict: ret = { "_id": self._format_credentials_id(dn), "_type": self.Type, @@ -338,11 +341,11 @@ def _normalize_credentials(self, dn: str, search_record: typing.Mapping): return ret - def _format_credentials_id(self, dn): + def _format_credentials_id(self, dn: str) -> str: return self.Prefix + base64.urlsafe_b64encode(dn.encode("utf-8")).decode("ascii") - def _build_search_filter(self, filtr: typing.Optional[str] = None): + def _build_search_filter(self, filtr: typing.Optional[str] = None) -> str: if not filtr: filterstr = self.Filter else: @@ -371,7 +374,7 @@ def _parse_timestamp(ts: str) -> datetime.datetime: return datetime.datetime.strptime(ts, r"%Y%m%d%H%M%S.%fZ") -def _prepare_attributes(config: typing.Mapping): +def _prepare_attributes(config: typing.Mapping) -> list: attr = set(config["attributes"].split(" ")) attr.add("createTimestamp") attr.add("modifyTimestamp") From 2e5d54aeca646bd24e0213640d9645e1576c39fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 17:37:05 +0200 Subject: [PATCH 5/9] include the rest in custom data --- seacatauth/credentials/providers/ldap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index 1c531856..94f5453a 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -336,7 +336,7 @@ def _normalize_credentials(self, dn: str, search_record: typing.Mapping) -> typi ret["_m"] = _parse_timestamp(v) if len(decoded_record) > 0: - ret["_ldap"] = decoded_record + ret["data"] = {k: v for k, v in decoded_record.items() if k in self.AttrList} return ret From 29f9d9404aab4ca735434ae193af57c28c0c8303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 17:37:25 +0200 Subject: [PATCH 6/9] more info by default --- seacatauth/credentials/providers/ldap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index 94f5453a..42054a6f 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -54,8 +54,8 @@ class LDAPCredentialsProvider(CredentialsProviderABC): "username": "cn=admin,dc=example,dc=org", "password": "admin", "base": "dc=example,dc=org", - "filter": "(&(objectClass=inetOrgPerson)(cn=*))", # should filter valid users only - "attributes": "mail mobile", + "filter": "(&(objectClass=inetOrgPerson)(cn=*))", + "attributes": "mail mobile userAccountControl displayName", # Path to CA file in PEM format "tls_cafile": "", @@ -72,7 +72,7 @@ class LDAPCredentialsProvider(CredentialsProviderABC): "tls_protocol_max": "", "tls_cipher_suite": "", - "attrusername": "cn", # LDAP attribute that should be used as a username, e.g. `uid` or `sAMAccountName` + "attrusername": "sAMAccountName", # LDAP attribute that should be used as a username, e.g. `uid` or `sAMAccountName` } From 5be7202bfe904b3534e903af9c28cb336f7567ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 17:46:16 +0200 Subject: [PATCH 7/9] flake8 --- seacatauth/credentials/providers/ldap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index 42054a6f..35ae1773 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -119,7 +119,7 @@ async def count(self, filtr=None) -> int: async def iterate(self, offset: int = 0, limit: int = -1, filtr: str = None): filterstr = self._build_search_filter(filtr) results = await self.ProactorService.execute(self._search_worker, filterstr) - for i in results[offset : (None if limit == -1 else limit + offset)]: + for i in results[offset:(None if limit == -1 else limit + offset)]: yield i @@ -280,7 +280,7 @@ def _normalize_credentials(self, dn: str, search_record: typing.Mapping) -> typi decoded_record = {"dn": dn} for k, v in search_record.items(): - if k =="userPassword": + if k == "userPassword": continue if isinstance(v, list): if len(v) == 0: From 38c7399d4a8a4ded7c05a3041f804848b85328a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Mon, 16 Sep 2024 17:50:42 +0200 Subject: [PATCH 8/9] update CHANGELOG.md --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fe0c6d1e..ec1aaeb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v24.36 ### Pre-releases +- v24.36-alpha5 - v24.36-alpha4 - v24.36-alpha3 - ~~v24.36-alpha2~~ @@ -18,9 +19,13 @@ - Hotfix: Session expiration in userinfo must match access token expiration (#414, `v24.29-alpha7`) ### Features +- Improve the default configuration of LDAP credentials provider (#422, `v24.36-alpha5`) - Duplicating roles (#416, `v24.36-alpha1`) - Run Batman with warning when there is no ElasticSearch URL (#413, `v24.29-alpha6`) +### Refactoring +- Refactor LDAP credentials provider (#422, `v24.36-alpha5`) + --- From 65f566dbf6dcbed27a373e97a926a7a1ae2cb93b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Hru=C5=A1ka?= Date: Tue, 17 Sep 2024 14:19:00 +0200 Subject: [PATCH 9/9] default filter to match different AD types --- seacatauth/credentials/providers/ldap.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/seacatauth/credentials/providers/ldap.py b/seacatauth/credentials/providers/ldap.py index 35ae1773..1d476273 100644 --- a/seacatauth/credentials/providers/ldap.py +++ b/seacatauth/credentials/providers/ldap.py @@ -54,7 +54,7 @@ class LDAPCredentialsProvider(CredentialsProviderABC): "username": "cn=admin,dc=example,dc=org", "password": "admin", "base": "dc=example,dc=org", - "filter": "(&(objectClass=inetOrgPerson)(cn=*))", + "filter": "|(objectClass=organizationalPerson)(objectClass=inetOrgPerson)", "attributes": "mail mobile userAccountControl displayName", # Path to CA file in PEM format @@ -85,7 +85,9 @@ def __init__(self, provider_id, config_section_name, proactor_svc): self.LdapUri = self.Config["uri"] self.Base = self.Config["base"] - self.Filter = self.Config["filter"] + self.Filter: str = self.Config["filter"] + if not (self.Filter.startswith("(") and self.Filter.endswith(")")): + self.Filter = "({})".format(self.Filter) self.AttrList = _prepare_attributes(self.Config) # Fields to filter by when locating a user