diff --git a/mylib/ldap.py b/mylib/ldap.py index 14b1d0d..cb18156 100644 --- a/mylib/ldap.py +++ b/mylib/ldap.py @@ -21,6 +21,34 @@ log = logging.getLogger(__name__) DEFAULT_ENCODING = 'utf-8' +def decode_ldap_value(value, encoding='utf-8'): + """ Decoding LDAP attribute values helper """ + if isinstance(value, bytes): + return value.decode_ldap_value(encoding) + if isinstance(value, list): + return [decode_ldap_value(v) for v in value] + if isinstance(value, dict): + return dict( + (key, decode_ldap_value(values)) + for key, values in value.items() + ) + return value + + +def encode_ldap_value(value, encoding='utf-8'): + """ Encoding LDAP attribute values helper """ + if isinstance(value, str): + return value.encode_ldap_value(encoding) + if isinstance(value, list): + return [encode_ldap_value(v) for v in value] + if isinstance(value, dict): + return dict( + (key, encode_ldap_value(values)) + for key, values in value.items() + ) + return value + + class LdapServer: """ LDAP server connection helper """ # pylint: disable=useless-object-inheritance @@ -182,9 +210,9 @@ class LdapServer: self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize) return ret - def add_object(self, dn, attrs): + def add_object(self, dn, attrs, encode=False): """ Add an object in LDAP directory """ - ldif = modlist.addModlist(attrs) + ldif = modlist.addModlist(encode_ldap_value(attrs) if encode else attrs) assert self.con or self.connect() try: self.logger.debug("LdapServer - Add %s", dn) @@ -195,11 +223,12 @@ class LdapServer: return False - def update_object(self, dn, old, new, ignore_attrs=None, relax=False): + def update_object(self, dn, old, new, ignore_attrs=None, relax=False, encode=False): """ Update an object in LDAP directory """ assert not relax or not self.v2, "Relax modification is not available on LDAP version 2" ldif = modlist.modifyModlist( - old, new, + encode_ldap_value(old) if encode else old, + encode_ldap_value(new) if encode else new, ignore_attr_types=ignore_attrs if ignore_attrs else [] ) if ldif == []: @@ -216,10 +245,11 @@ class LdapServer: return False @staticmethod - def update_need(old, new, ignore_attrs=None): + def update_need(old, new, ignore_attrs=None, encode=False): """ Check if an update is need on a LDAP object based on its old and new attributes values """ ldif = modlist.modifyModlist( - old, new, + encode_ldap_value(old) if encode else old, + encode_ldap_value(new) if encode else new, ignore_attr_types=ignore_attrs if ignore_attrs else [] ) if ldif == []: @@ -227,19 +257,24 @@ class LdapServer: return True @staticmethod - def get_changes(old, new, ignore_attrs=None): + def get_changes(old, new, ignore_attrs=None, encode=False): """ Retrieve changes (as modlist) on an object based on its old and new attributes values """ return modlist.modifyModlist( - old, new, + encode_ldap_value(old) if encode else old, + encode_ldap_value(new) if encode else new, ignore_attr_types=ignore_attrs if ignore_attrs else [] ) @staticmethod - def format_changes(old, new, ignore_attrs=None, prefix=None): + def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False): """ Format changes (modlist) on an object based on its old and new attributes values to display/log it """ msg = [] prefix = prefix if prefix else '' - for (op, attr, val) in modlist.modifyModlist(old, new, ignore_attr_types=ignore_attrs if ignore_attrs else []): + for (op, attr, val) in modlist.modifyModlist( + encode_ldap_value(old) if encode else old, + encode_ldap_value(new) if encode else new, + ignore_attr_types=ignore_attrs if ignore_attrs else [] + ): if op == ldap.MOD_ADD: # pylint: disable=no-member op = 'ADD' elif op == ldap.MOD_DELETE: # pylint: disable=no-member @@ -309,19 +344,19 @@ class LdapServer: return obj[0][0] @staticmethod - def get_attr(obj, attr, all=None, default=None): + def get_attr(obj, attr, all_values=None, default=None, decode=False): """ Retreive an on object attribute value(s) from the object entry in LDAP search result """ if attr not in obj: for k in obj: if k.lower() == attr.lower(): attr = k break - if all is not None: + if all_values: if attr in obj: - return obj[attr] + return decode_ldap_value(obj[attr]) if decode else obj[attr] return default or [] if attr in obj: - return obj[attr][0] + return decode_ldap_value(obj[attr][0]) if decode else obj[attr][0] return default @@ -456,7 +491,7 @@ class LdapClient: """ obj = dict(dn=dn) for attr in attrs: - obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all=True)] + obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all_values=True)] return obj @staticmethod