ldap.LdapServer: Add encode/decode helpers and parameters

This commit is contained in:
Benjamin Renard 2022-06-23 18:34:48 +02:00
parent fe3e3ed5f4
commit 2bc9964b12

View file

@ -21,6 +21,34 @@ log = logging.getLogger(__name__)
DEFAULT_ENCODING = 'utf-8'
def decode_ldap_value(value, encoding='utf-8'):
""" Decoding LDAP attribute values helper """
if isinstance(value, bytes):
return value.decode_ldap_value(encoding)
if isinstance(value, list):
return [decode_ldap_value(v) for v in value]
if isinstance(value, dict):
return dict(
(key, decode_ldap_value(values))
for key, values in value.items()
)
return value
def encode_ldap_value(value, encoding='utf-8'):
""" Encoding LDAP attribute values helper """
if isinstance(value, str):
return value.encode_ldap_value(encoding)
if isinstance(value, list):
return [encode_ldap_value(v) for v in value]
if isinstance(value, dict):
return dict(
(key, encode_ldap_value(values))
for key, values in value.items()
)
return value
class LdapServer:
""" LDAP server connection helper """ # pylint: disable=useless-object-inheritance
@ -182,9 +210,9 @@ class LdapServer:
self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize)
return ret
def add_object(self, dn, attrs):
def add_object(self, dn, attrs, encode=False):
""" Add an object in LDAP directory """
ldif = modlist.addModlist(attrs)
ldif = modlist.addModlist(encode_ldap_value(attrs) if encode else attrs)
assert self.con or self.connect()
try:
self.logger.debug("LdapServer - Add %s", dn)
@ -195,11 +223,12 @@ class LdapServer:
return False
def update_object(self, dn, old, new, ignore_attrs=None, relax=False):
def update_object(self, dn, old, new, ignore_attrs=None, relax=False, encode=False):
""" Update an object in LDAP directory """
assert not relax or not self.v2, "Relax modification is not available on LDAP version 2"
ldif = modlist.modifyModlist(
old, new,
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else []
)
if ldif == []:
@ -216,10 +245,11 @@ class LdapServer:
return False
@staticmethod
def update_need(old, new, ignore_attrs=None):
def update_need(old, new, ignore_attrs=None, encode=False):
""" Check if an update is need on a LDAP object based on its old and new attributes values """
ldif = modlist.modifyModlist(
old, new,
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else []
)
if ldif == []:
@ -227,19 +257,24 @@ class LdapServer:
return True
@staticmethod
def get_changes(old, new, ignore_attrs=None):
def get_changes(old, new, ignore_attrs=None, encode=False):
""" Retrieve changes (as modlist) on an object based on its old and new attributes values """
return modlist.modifyModlist(
old, new,
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else []
)
@staticmethod
def format_changes(old, new, ignore_attrs=None, prefix=None):
def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False):
""" Format changes (modlist) on an object based on its old and new attributes values to display/log it """
msg = []
prefix = prefix if prefix else ''
for (op, attr, val) in modlist.modifyModlist(old, new, ignore_attr_types=ignore_attrs if ignore_attrs else []):
for (op, attr, val) in modlist.modifyModlist(
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else []
):
if op == ldap.MOD_ADD: # pylint: disable=no-member
op = 'ADD'
elif op == ldap.MOD_DELETE: # pylint: disable=no-member
@ -309,19 +344,19 @@ class LdapServer:
return obj[0][0]
@staticmethod
def get_attr(obj, attr, all=None, default=None):
def get_attr(obj, attr, all_values=None, default=None, decode=False):
""" Retreive an on object attribute value(s) from the object entry in LDAP search result """
if attr not in obj:
for k in obj:
if k.lower() == attr.lower():
attr = k
break
if all is not None:
if all_values:
if attr in obj:
return obj[attr]
return decode_ldap_value(obj[attr]) if decode else obj[attr]
return default or []
if attr in obj:
return obj[attr][0]
return decode_ldap_value(obj[attr][0]) if decode else obj[attr][0]
return default
@ -456,7 +491,7 @@ class LdapClient:
"""
obj = dict(dn=dn)
for attr in attrs:
obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all=True)]
obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all_values=True)]
return obj
@staticmethod