ldap.LdapServer: Add encode/decode helpers and parameters

This commit is contained in:
Benjamin Renard 2022-06-23 18:34:48 +02:00
parent fe3e3ed5f4
commit 2bc9964b12

View file

@ -21,6 +21,34 @@ log = logging.getLogger(__name__)
DEFAULT_ENCODING = 'utf-8' DEFAULT_ENCODING = 'utf-8'
def decode_ldap_value(value, encoding='utf-8'):
""" Decoding LDAP attribute values helper """
if isinstance(value, bytes):
return value.decode_ldap_value(encoding)
if isinstance(value, list):
return [decode_ldap_value(v) for v in value]
if isinstance(value, dict):
return dict(
(key, decode_ldap_value(values))
for key, values in value.items()
)
return value
def encode_ldap_value(value, encoding='utf-8'):
""" Encoding LDAP attribute values helper """
if isinstance(value, str):
return value.encode_ldap_value(encoding)
if isinstance(value, list):
return [encode_ldap_value(v) for v in value]
if isinstance(value, dict):
return dict(
(key, encode_ldap_value(values))
for key, values in value.items()
)
return value
class LdapServer: class LdapServer:
""" LDAP server connection helper """ # pylint: disable=useless-object-inheritance """ LDAP server connection helper """ # pylint: disable=useless-object-inheritance
@ -182,9 +210,9 @@ class LdapServer:
self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize) self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize)
return ret return ret
def add_object(self, dn, attrs): def add_object(self, dn, attrs, encode=False):
""" Add an object in LDAP directory """ """ Add an object in LDAP directory """
ldif = modlist.addModlist(attrs) ldif = modlist.addModlist(encode_ldap_value(attrs) if encode else attrs)
assert self.con or self.connect() assert self.con or self.connect()
try: try:
self.logger.debug("LdapServer - Add %s", dn) self.logger.debug("LdapServer - Add %s", dn)
@ -195,11 +223,12 @@ class LdapServer:
return False return False
def update_object(self, dn, old, new, ignore_attrs=None, relax=False): def update_object(self, dn, old, new, ignore_attrs=None, relax=False, encode=False):
""" Update an object in LDAP directory """ """ Update an object in LDAP directory """
assert not relax or not self.v2, "Relax modification is not available on LDAP version 2" assert not relax or not self.v2, "Relax modification is not available on LDAP version 2"
ldif = modlist.modifyModlist( ldif = modlist.modifyModlist(
old, new, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else []
) )
if ldif == []: if ldif == []:
@ -216,10 +245,11 @@ class LdapServer:
return False return False
@staticmethod @staticmethod
def update_need(old, new, ignore_attrs=None): def update_need(old, new, ignore_attrs=None, encode=False):
""" Check if an update is need on a LDAP object based on its old and new attributes values """ """ Check if an update is need on a LDAP object based on its old and new attributes values """
ldif = modlist.modifyModlist( ldif = modlist.modifyModlist(
old, new, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else []
) )
if ldif == []: if ldif == []:
@ -227,19 +257,24 @@ class LdapServer:
return True return True
@staticmethod @staticmethod
def get_changes(old, new, ignore_attrs=None): def get_changes(old, new, ignore_attrs=None, encode=False):
""" Retrieve changes (as modlist) on an object based on its old and new attributes values """ """ Retrieve changes (as modlist) on an object based on its old and new attributes values """
return modlist.modifyModlist( return modlist.modifyModlist(
old, new, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else []
) )
@staticmethod @staticmethod
def format_changes(old, new, ignore_attrs=None, prefix=None): def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False):
""" Format changes (modlist) on an object based on its old and new attributes values to display/log it """ """ Format changes (modlist) on an object based on its old and new attributes values to display/log it """
msg = [] msg = []
prefix = prefix if prefix else '' prefix = prefix if prefix else ''
for (op, attr, val) in modlist.modifyModlist(old, new, ignore_attr_types=ignore_attrs if ignore_attrs else []): for (op, attr, val) in modlist.modifyModlist(
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else []
):
if op == ldap.MOD_ADD: # pylint: disable=no-member if op == ldap.MOD_ADD: # pylint: disable=no-member
op = 'ADD' op = 'ADD'
elif op == ldap.MOD_DELETE: # pylint: disable=no-member elif op == ldap.MOD_DELETE: # pylint: disable=no-member
@ -309,19 +344,19 @@ class LdapServer:
return obj[0][0] return obj[0][0]
@staticmethod @staticmethod
def get_attr(obj, attr, all=None, default=None): def get_attr(obj, attr, all_values=None, default=None, decode=False):
""" Retreive an on object attribute value(s) from the object entry in LDAP search result """ """ Retreive an on object attribute value(s) from the object entry in LDAP search result """
if attr not in obj: if attr not in obj:
for k in obj: for k in obj:
if k.lower() == attr.lower(): if k.lower() == attr.lower():
attr = k attr = k
break break
if all is not None: if all_values:
if attr in obj: if attr in obj:
return obj[attr] return decode_ldap_value(obj[attr]) if decode else obj[attr]
return default or [] return default or []
if attr in obj: if attr in obj:
return obj[attr][0] return decode_ldap_value(obj[attr][0]) if decode else obj[attr][0]
return default return default
@ -456,7 +491,7 @@ class LdapClient:
""" """
obj = dict(dn=dn) obj = dict(dn=dn)
for attr in attrs: for attr in attrs:
obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all=True)] obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all_values=True)]
return obj return obj
@staticmethod @staticmethod