# -*- coding: utf-8 -*-

""" LDAP server connection helper """

import copy
import datetime
import logging

import dateutil.parser
import dateutil.tz
import ldap
import ldap.sasl
from ldap.controls import SimplePagedResultsControl
from ldap.controls.simple import RelaxRulesControl
import ldap.modlist as modlist
import pytz

from mylib import pretty_format_dict

log = logging.getLogger(__name__)


class LdapServer:
    """ LDAP server connection helper """

    uri = None
    dn = None
    pwd = None
    v2 = None
    con = 0  # 0 until connect() succeeds, then the python-ldap connection object

    def __init__(self, uri, dn=None, pwd=None, v2=None, raiseOnError=False, logger=False):
        """
        :param uri: The LDAP server URI
        :param dn: Optional bind DN (simple bind is used when provided)
        :param pwd: Optional bind password
        :param v2: If true, use LDAP protocol version 2 (default: version 3)
        :param raiseOnError: If true, errors raise LdapServerException instead of only being logged
        :param logger: Optional logger (default: this module's logger)
        """
        self.uri = uri
        self.dn = dn
        self.pwd = pwd
        self.raiseOnError = raiseOnError
        if v2:
            self.v2 = True
        if logger:
            self.logger = logger
        else:
            self.logger = logging.getLogger(__name__)

    def _error(self, error, level=logging.WARNING):
        """ Raise LdapServerException (in raiseOnError mode) or log the error message at the given level """
        if self.raiseOnError:
            raise LdapServerException(error)
        self.logger.log(level, error)

    def connect(self):
        """
        Start connection to LDAP server (no-op if already connected)

        Uses a simple bind when a bind DN is configured, a SASL EXTERNAL bind on
        ldapi:// URIs, and stays anonymous otherwise.

        :return: True on success (or if already connected), False on error
        """
        if self.con == 0:
            try:
                con = ldap.initialize(self.uri)
                if self.v2:
                    con.protocol_version = ldap.VERSION2  # pylint: disable=no-member
                else:
                    con.protocol_version = ldap.VERSION3  # pylint: disable=no-member

                if self.dn:
                    con.simple_bind_s(self.dn, self.pwd)
                elif self.uri.startswith('ldapi://'):
                    con.sasl_interactive_bind_s("", ldap.sasl.external())

                self.con = con
                return True
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error('LdapServer - Error connecting and binding to LDAP server : %s' % e, logging.CRITICAL)
                return False
        return True

    @staticmethod
    def get_scope(scope):
        """
        Map scope parameter to python-ldap value

        :param scope: One of 'base', 'one' or 'sub'
        :raise Exception: If the scope is unknown
        """
        if scope == 'base':
            return ldap.SCOPE_BASE  # pylint: disable=no-member
        if scope == 'one':
            return ldap.SCOPE_ONELEVEL  # pylint: disable=no-member
        if scope == 'sub':
            return ldap.SCOPE_SUBTREE  # pylint: disable=no-member
        raise Exception("Unknown LDAP scope '%s'" % scope)

    def search(self, basedn, filterstr=None, attrs=None, sizelimit=0, scope=None):
        """
        Run a search on LDAP server

        :param basedn: The search base DN
        :param filterstr: Optional LDAP filter (default: (objectClass=*))
        :param attrs: Optional list of attributes to retrieve (default: all)
        :param sizelimit: Optional maximum number of entries to return (0 = no limit)
        :param scope: Optional search scope: base, one or sub (default: sub)
        :return: A dict of entries' attributes indexed by their DN
        """
        res_id = self.con.search(
            basedn,
            self.get_scope(scope if scope else 'sub'),
            filterstr if filterstr else '(objectClass=*)',
            attrs if attrs else []
        )
        ret = {}
        # Stop once sizelimit entries have been collected (the previous check
        # happened one entry too late and returned sizelimit + 1 entries)
        while not sizelimit or len(ret) < sizelimit:
            res_type, res_data = self.con.result(res_id, 0)
            if res_data == []:
                break
            if res_type == ldap.RES_SEARCH_ENTRY:  # pylint: disable=no-member
                ret[res_data[0][0]] = res_data[0][1]
        return ret

    def get_object(self, dn, filterstr=None, attrs=None):
        """
        Retrieve a LDAP object specified by its DN

        :return: The object attributes (as dict), or None if not found
        """
        result = self.search(dn, filterstr=filterstr, scope='base', attrs=attrs)
        return result[dn] if dn in result else None

    def paged_search(self, basedn, filterstr, attrs, scope='sub', pagesize=500):
        """
        Run a paged search on LDAP server (RFC 2696 simple paged results control)

        :param basedn: The search base DN
        :param filterstr: The LDAP filter
        :param attrs: The list of attributes to retrieve
        :param scope: The search scope: base, one or sub (default: sub)
        :param pagesize: The maximum number of entries per page (default: 500)
        :return: A dict of entries' attributes indexed by their DN, or False on error
        """
        assert not self.v2, "Paged search is not available on LDAP version 2"
        # Initialize SimplePagedResultsControl object
        page_control = SimplePagedResultsControl(
            True,
            size=pagesize,
            cookie=''  # Start without cookie
        )
        ret = {}
        pages_count = 0
        self.logger.debug(
            "LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d and attrs=%s",
            basedn, filterstr, scope, pagesize, attrs
        )
        while True:
            pages_count += 1
            self.logger.debug(
                "LdapServer - Paged search: request page %d with a maximum of %d objects (current total count: %d)",
                pages_count, pagesize, len(ret)
            )
            try:
                res_id = self.con.search_ext(
                    basedn,
                    self.get_scope(scope),
                    filterstr,
                    attrs,
                    serverctrls=[page_control]
                )
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error('LdapServer - Error running paged search on LDAP server: %s' % e, logging.CRITICAL)
                return False
            try:
                rtype, rdata, rmsgid, rctrls = self.con.result3(res_id)  # pylint: disable=unused-variable
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error('LdapServer - Error pulling paged search result from LDAP server: %s' % e, logging.CRITICAL)
                return False

            # Detect and catch PagedResultsControl answer from rctrls
            result_page_control = None
            if rctrls:
                for rctrl in rctrls:
                    if rctrl.controlType == SimplePagedResultsControl.controlType:
                        result_page_control = rctrl
                        break

            # If PagedResultsControl answer not detected, paged search can not work
            if not result_page_control:
                self._error('LdapServer - Server ignores RFC2696 control, paged search can not works', logging.CRITICAL)
                return False

            # Store results of this page
            for obj_dn, obj_attrs in rdata:
                ret[obj_dn] = obj_attrs

            # If no cookie returned, we are done
            if not result_page_control.cookie:
                break

            # Otherwise, set cookie for the next search
            page_control.cookie = result_page_control.cookie

        self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize)
        return ret

    def add_object(self, dn, attrs):
        """
        Add an object in LDAP directory

        :param dn: The object DN
        :param attrs: The object attributes (as dict)
        :return: True on success, False on error
        """
        ldif = modlist.addModlist(attrs)
        try:
            self.logger.debug("LdapServer - Add %s", dn)
            self.con.add_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error("LdapServer - Error adding %s : %s" % (dn, e), logging.ERROR)

        return False

    def update_object(self, dn, old, new, ignore_attrs=None, relax=False):
        """
        Update an object in LDAP directory

        :param dn: The object DN
        :param old: The old attributes values (as dict)
        :param new: The new attributes values (as dict)
        :param ignore_attrs: Optional list of attributes to ignore
        :param relax: If true, send the modification with the Relax Rules control
        :return: True on success (or if there is nothing to change), False on error
        """
        assert not relax or not self.v2, "Relax modification is not available on LDAP version 2"
        ldif = modlist.modifyModlist(
            old, new,
            ignore_attr_types=ignore_attrs if ignore_attrs else []
        )
        if ldif == []:
            return True
        try:
            if relax:
                self.con.modify_ext_s(dn, ldif, serverctrls=[RelaxRulesControl()])
            else:
                self.con.modify_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error("LdapServer - Error updating %s : %s\nOld : %s\nNew : %s" % (dn, e, old, new), logging.ERROR)
        return False

    @staticmethod
    def update_need(old, new, ignore_attrs=None):
        """ Check if an update is needed on a LDAP object based on its old and new attributes values """
        ldif = modlist.modifyModlist(
            old, new,
            ignore_attr_types=ignore_attrs if ignore_attrs else []
        )
        return bool(ldif)

    @staticmethod
    def get_changes(old, new, ignore_attrs=None):
        """ Retrieve changes (as modlist) on an object based on its old and new attributes values """
        return modlist.modifyModlist(
            old, new,
            ignore_attr_types=ignore_attrs if ignore_attrs else []
        )

    @staticmethod
    def format_changes(old, new, ignore_attrs=None, prefix=None):
        """
        Format changes (modlist) on an object based on its old and new attributes values to display/log it

        :param prefix: Optional prefix of each line (default: empty string)
        :return: One line per modification, joined with newlines
        """
        # Normalize once so that every line uses the same prefix (a None prefix
        # previously leaked as "None" on non-DELETE lines)
        prefix = prefix if prefix else ''
        msg = []
        for (op, attr, val) in modlist.modifyModlist(old, new, ignore_attr_types=ignore_attrs if ignore_attrs else []):
            if op == ldap.MOD_ADD:  # pylint: disable=no-member
                op = 'ADD'
            elif op == ldap.MOD_DELETE:  # pylint: disable=no-member
                op = 'DELETE'
            elif op == ldap.MOD_REPLACE:  # pylint: disable=no-member
                op = 'REPLACE'
            else:
                op = 'UNKNOWN (=%s)' % op
            if val is None and op == 'DELETE':
                msg.append('%s - %s %s' % (prefix, op, attr))
            else:
                msg.append('%s - %s %s: %s' % (prefix, op, attr, val))
        return '\n'.join(msg)

    def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True):
        """
        Rename an object in LDAP directory

        :param dn: The current object DN
        :param new_rdn: The new RDN, or a complete new DN (split into RDN and superior DN on the first comma)
        :param new_sup: Optional new superior DN (must not be combined with a complete DN in new_rdn)
        :param delete_old: Delete the old RDN value(s) from the object (default: True)
        :return: True on success, False on error
        """
        # If new_rdn is a complete DN, split new RDN and new superior DN
        if len(new_rdn.split(',')) > 1:
            self.logger.debug(
                "LdapServer - Rename with a full new DN detected (%s): split new RDN and new superior DN",
                new_rdn
            )
            assert new_sup is None, "You can't provide a complete DN as new_rdn and also provide new_sup parameter"
            new_dn_parts = new_rdn.split(',')
            new_sup = ','.join(new_dn_parts[1:])
            new_rdn = new_dn_parts[0]
        try:
            self.logger.debug(
                "LdapServer - Rename %s in %s (new superior: %s, delete old: %s)",
                dn, new_rdn, "same" if new_sup is None else new_sup, delete_old
            )
            self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(
                "LdapServer - Error renaming %s in %s (new superior: %s, delete old: %s): %s" % (
                    dn, new_rdn, "same" if new_sup is None else new_sup, delete_old, e
                ),
                logging.ERROR
            )
        return False

    def drop_object(self, dn):
        """
        Drop an object in LDAP directory

        :return: True on success, False on error
        """
        try:
            self.logger.debug("LdapServer - Delete %s", dn)
            self.con.delete_s(dn)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error("LdapServer - Error deleting %s : %s" % (dn, e), logging.ERROR)

        return False

    @staticmethod
    def get_dn(obj):
        """ Retrieve an object DN from its entry in LDAP search result """
        return obj[0][0]

    @staticmethod
    def get_attr(obj, attr, all=None, default=None):
        """
        Retrieve an object attribute value(s) from the object entry in LDAP search result

        The attribute name is matched case-insensitively when no exact match exists.

        :param all: If not None, return the list of all values (default: the first value only)
        :param default: Value returned if the attribute is missing
        """
        if attr not in obj:
            for k in obj:
                if k.lower() == attr.lower():
                    attr = k
                    break
        if all is not None:
            if attr in obj:
                return obj[attr]
            return default or []
        if attr in obj:
            return obj[attr][0]
        return default


class LdapServerException(Exception):
    """ Generic exception raised by LdapServer """

    def __init__(self, msg):
        super().__init__(msg)


class LdapClientException(LdapServerException):
    """ Generic exception raised by LdapClient """

    def __init__(self, msg):
        super().__init__(msg)


class LdapClient:
    """ LDAP Client (based on python-mylib.LdapServer) """

    options = {}

    # Cache of retrieved objects: {type name: {DN: object}}. Re-initialized
    # per instance in __init__ so that clients do not share cached objects.
    _cached_objects = dict()

    def __init__(self, options):
        """
        :param options: Options object providing ldap_uri, ldap_binddn, ldap_bindpwd and just_try attributes
        """
        self.options = options
        self._cached_objects = {}
        log.info("Connect to LDAP server %s as %s", options.ldap_uri, options.ldap_binddn)
        self.cnx = LdapServer(options.ldap_uri, dn=options.ldap_binddn, pwd=options.ldap_bindpwd, raiseOnError=True)
        self.cnx.connect()

    @classmethod
    def decode(cls, value):
        """ Decode bytes (or list of bytes) to str using UTF-8, ignoring undecodable bytes """
        if isinstance(value, list):
            return [cls.decode(v) for v in value]
        if isinstance(value, str):
            return value
        return value.decode('utf-8', 'ignore')

    @classmethod
    def encode(cls, value):
        """ Encode str (or list of str) to bytes using UTF-8 """
        if isinstance(value, list):
            return [cls.encode(v) for v in value]
        if isinstance(value, bytes):
            return value
        return value.encode('utf-8')

    def get_attrs(self, dn, attrs):
        """ Build an object dict (with a 'dn' key) from raw LDAP attributes, decoding all values """
        obj = dict(dn=dn)
        for attr in attrs:
            obj[attr] = [self.decode(v) for v in self.cnx.get_attr(attrs, attr, all=True)]
        return obj

    @staticmethod
    def get_attr(obj, attr, default="", all_values=False):
        """
        Retrieve an attribute value from an object dict

        :param default: Value returned if the attribute is missing or empty
        :param all_values: Return the list of all values (default: first value only)
        """
        vals = obj.get(attr, [])
        if vals:
            return vals if all_values else vals[0]
        return default if default or not all_values else []

    @staticmethod
    def _get_attr_values_ci(obj, attr):
        """ Retrieve all values of an attribute in a dict, matching its name case-insensitively ([] if missing) """
        attr = attr.lower()
        for key, values in obj.items():
            if key.lower() == attr and values:
                return values
        return []

    def get_objects(self, name, filterstr, basedn, attrs, key_attr=None, warn=True):
        """
        Retrieve (and cache) the objects of a type

        :param name: The object type name (used as cache key and in log messages)
        :param key_attr: Optional attribute used as key of the returned dict (default: DN)
        :param warn: Log at WARNING level (instead of DEBUG) if no object is found
        :return: A dict of objects
        """
        if name in self._cached_objects:
            log.debug('Retreived %s objects from cache', name)
        else:
            log.debug('Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn)
            ldap_data = self.cnx.search(
                basedn=basedn,
                filterstr=filterstr,
                attrs=attrs
            )

            if not ldap_data:
                if warn:
                    log.warning('No %s found in LDAP', name)
                else:
                    log.debug('No %s found in LDAP', name)
                return {}

            objects = {}
            for obj_dn, obj_attrs in ldap_data.items():
                objects[obj_dn] = self.get_attrs(obj_dn, obj_attrs)
            self._cached_objects[name] = objects
        if not key_attr or key_attr == 'dn':
            return self._cached_objects[name]
        return dict(
            (self.get_attr(self._cached_objects[name][dn], key_attr), self._cached_objects[name][dn])
            for dn in self._cached_objects[name]
        )

    def get_object(self, type_name, object_name, filterstr, basedn, attrs, warn=True):
        """
        Retrieve a single object (not cached)

        :return: The object dict, or None if not found
        :raise LdapClientException: If more than one object is found
        """
        log.debug('Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")', type_name, object_name, filterstr, basedn)
        ldap_data = self.cnx.search(
            basedn=basedn, filterstr=filterstr,
            attrs=attrs
        )

        if not ldap_data:
            if warn:
                log.warning('No %s "%s" found in LDAP', type_name, object_name)
            else:
                log.debug('No %s "%s" found in LDAP', type_name, object_name)
            return None

        if len(ldap_data) > 1:
            raise LdapClientException('More than one %s "%s": %s' % (type_name, object_name, ' / '.join(ldap_data.keys())))

        dn = next(iter(ldap_data))
        return self.get_attrs(dn, ldap_data[dn])

    def get_object_by_dn(self, type_name, dn, populate_cache_method=None, warn=True):
        """
        Retrieve a cached object by its DN

        :return: The object dict, None if not found, or False if the type is not cached and no populate method is given
        """
        if type_name not in self._cached_objects:
            if not populate_cache_method:
                return False
            populate_cache_method()
            if type_name not in self._cached_objects:
                if warn:
                    log.warning('No %s found in LDAP', type_name)
                else:
                    log.debug('No %s found in LDAP', type_name)
                return None
        if dn not in self._cached_objects[type_name]:
            if warn:
                log.warning('No %s found with DN "%s"', type_name, dn)
            else:
                log.debug('No %s found with DN "%s"', type_name, dn)
            return None
        return self._cached_objects[type_name][dn]

    @classmethod
    def object_attr_mached(cls, obj, attr, value, case_sensitive=False):
        """ Check if an object attribute has the specified value (case-insensitive by default) """
        if case_sensitive:
            return value in cls.get_attr(obj, attr, all_values=True)
        return value.lower() in [v.lower() for v in cls.get_attr(obj, attr, all_values=True)]

    def get_object_by_attr(self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True):
        """
        Retrieve a cached object by an attribute value

        :return: The object dict, None if not found, or False if the type is not cached and no populate method is given
        :raise LdapClientException: If more than one object matches
        """
        if type_name not in self._cached_objects:
            if not populate_cache_method:
                return False
            populate_cache_method()
            if type_name not in self._cached_objects:
                if warn:
                    log.warning('No %s found in LDAP', type_name)
                else:
                    log.debug('No %s found in LDAP', type_name)
                return None

        matched = dict(
            (dn, obj)
            for dn, obj in self._cached_objects[type_name].items()
            if self.object_attr_mached(obj, attr, value, case_sensitive=case_sensitive)
        )
        if not matched:
            if warn:
                log.warning('No %s found with %s="%s"', type_name, attr, value)
            else:
                log.debug('No %s found with %s="%s"', type_name, attr, value)
            return None
        if len(matched) > 1:
            raise LdapClientException('More than one %s with %s="%s" found: %s' % (type_name, attr, value, ' / '.join(matched.keys())))

        dn = next(iter(matched))
        return matched[dn]

    def get_changes(self, ldap_obj, attrs, protected_attrs=None):
        """
        Retrieve changes on a LDAP object

        :param ldap_obj: The original LDAP object
        :param attrs: The new LDAP object attributes values
        :param protected_attrs: Optional list of protected attributes
        :return: A (old, new) tuple of dicts with UTF-8 encoded values, or None if there is no change
        """
        old = {}
        new = {}
        protected_attrs = [a.lower() for a in protected_attrs or list()]
        protected_attrs.append('dn')
        # New/updated attributes
        for attr in attrs:
            if protected_attrs and attr.lower() in protected_attrs:
                continue
            if attr in ldap_obj and ldap_obj[attr]:
                if sorted(ldap_obj[attr]) == sorted(attrs[attr]):
                    continue
                old[attr] = self.encode(ldap_obj[attr])
            new[attr] = self.encode(attrs[attr])
        # Deleted attributes
        for attr in ldap_obj:
            if (not protected_attrs or attr.lower() not in protected_attrs) and ldap_obj[attr] and attr not in attrs:
                old[attr] = self.encode(ldap_obj[attr])
        if old == new:
            return None
        return (old, new)

    def add_object(self, dn, attrs):
        """
        Add an object

        :param dn: The LDAP object DN
        :param attrs: The LDAP object attributes (as dict)
        :return: True on success (or in just-try mode), False on error
        """
        attrs = dict(
            (attr, self.encode(values))
            for attr, values in attrs.items()
        )
        try:
            if self.options.just_try:
                log.debug('Just-try mode : do not really add object in LDAP')
                return True
            return self.cnx.add_object(dn, attrs)
        except LdapServerException:
            log.error(
                "An error occurred adding object %s in LDAP:\n%s\n",
                dn, pretty_format_dict(attrs), exc_info=True
            )
        return False

    def update_object(self, ldap_obj, changes, protected_attrs=None, rdn_attr=None):
        """
        Update an object

        If the RDN attribute changes, the object is renamed first. The RDN attribute
        is then removed from the changes dicts, which are modified in place.

        :param ldap_obj: The original LDAP object
        :param changes: The changes to make on LDAP object (as formatted by get_changes() method)
        :param protected_attrs: An optional list of protected attributes
        :param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected)
        :return: True on success (or in just-try mode), False on error
        """
        assert isinstance(changes, (list, tuple)) and len(changes) == 2 and isinstance(changes[0], dict) and isinstance(changes[1], dict), "changes parameter must be a result of get_changes() method (%s given)" % type(changes)
        if not rdn_attr:
            rdn_attr = ldap_obj['dn'].split('=')[0]
            log.debug('Auto-detected RDN attribute from DN: %s => %s', ldap_obj['dn'], rdn_attr)
        # Attribute names in DN and in changes may differ in case: match case-insensitively
        old_rdn_values = self._get_attr_values_ci(changes[0], rdn_attr)
        new_rdn_values = self._get_attr_values_ci(changes[1], rdn_attr)
        if old_rdn_values or new_rdn_values:
            if not new_rdn_values:
                log.error(
                    "%s : Attribute %s can't be deleted because it's used as RDN.",
                    ldap_obj['dn'], rdn_attr
                )
                return False
            log.debug(
                '%s: Changes detected on %s RDN attribute: must rename object before updating it',
                ldap_obj['dn'], rdn_attr
            )

            # Compute new object DN
            dn_parts = ldap_obj['dn'].split(',')
            basedn = ','.join(dn_parts[1:])
            # Values in changes are UTF-8 encoded by get_changes(): decode them
            # before formatting, otherwise the RDN would read "attr=b'value'"
            new_rdn = '%s=%s' % (rdn_attr, self.decode(new_rdn_values[0]))
            new_dn = '%s,%s' % (new_rdn, basedn)

            # Rename object
            log.debug('%s: Rename to %s', ldap_obj['dn'], new_dn)
            if not self.move_object(ldap_obj, new_rdn):
                return False

            # Remove RDN in changes list (iterate over a copy of the keys: the
            # dicts are modified inside the loop)
            for attrs_changes in (changes[0], changes[1]):
                for attr in list(attrs_changes):
                    if attr.lower() == rdn_attr.lower():
                        del attrs_changes[attr]

            # Check that there are other changes
            if not changes[0] and not changes[1]:
                log.debug('%s: No other change after renaming', new_dn)
                return True

            # Otherwise, update object DN
            ldap_obj['dn'] = new_dn
        else:
            log.debug('%s: No change detected on RDN attibute %s', ldap_obj['dn'], rdn_attr)

        try:
            if self.options.just_try:
                log.debug('Just-try mode : do not really update object in LDAP')
                return True
            return self.cnx.update_object(
                ldap_obj['dn'],
                changes[0],
                changes[1],
                ignore_attrs=protected_attrs
            )
        except LdapServerException:
            log.error(
                "An error occurred updating object %s in LDAP:\n%s\n -> \n%s\n\n",
                ldap_obj['dn'], pretty_format_dict(changes[0]), pretty_format_dict(changes[1]),
                exc_info=True
            )
        return False

    def move_object(self, ldap_obj, new_dn_or_rdn):
        """
        Move/rename an object

        :param ldap_obj: The original LDAP object
        :param new_dn_or_rdn: The new LDAP object's DN (or RDN)
        :return: True on success (or in just-try mode), False on error
        """
        try:
            if self.options.just_try:
                log.debug('Just-try mode : do not really move object in LDAP')
                return True
            return self.cnx.rename_object(ldap_obj['dn'], new_dn_or_rdn)
        except LdapServerException:
            log.error(
                "An error occurred moving object %s in LDAP (destination: %s)",
                ldap_obj['dn'], new_dn_or_rdn, exc_info=True
            )
        return False

    def drop_object(self, ldap_obj):
        """
        Drop/delete an object

        :param ldap_obj: The original LDAP object to delete/drop
        :return: True on success (or in just-try mode), False on error
        """
        try:
            if self.options.just_try:
                log.debug('Just-try mode : do not really drop object in LDAP')
                return True
            return self.cnx.drop_object(ldap_obj['dn'])
        except LdapServerException:
            log.error(
                "An error occurred removing object %s in LDAP",
                ldap_obj['dn'], exc_info=True
            )
        return False


#
# LDAP date string helpers
#

def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
    """
    Convert LDAP date string to datetime.datetime object

    :param value: The LDAP date string to convert
    :param to_timezone: If specified, the return datetime will be converted to this specific timezone
                        ('local' for the server local timezone)
                        (optional, default : timezone of the LDAP date string)
    :param default_timezone: The timezone used if LDAP date string does not specify the timezone
                             ('local' for the server local timezone)
                             (optional, default : UTC)
    :param naive: Use naive datetime : return naive datetime object (without timezone conversion from LDAP)
    """
    assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), 'to_timezone must be None, a datetime.tzinfo object or a string (not %s)' % type(to_timezone)
    assert default_timezone is None or isinstance(default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), 'default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not %s)' % type(default_timezone)
    date = dateutil.parser.parse(value, dayfirst=False)
    if not date.tzinfo:
        if naive:
            return date
        if not default_timezone:
            default_timezone = pytz.utc
        elif default_timezone == 'local':
            default_timezone = dateutil.tz.tzlocal()
        elif isinstance(default_timezone, str):
            default_timezone = pytz.timezone(default_timezone)
        # pytz DST-aware timezones must be applied with localize(), not replace()
        if isinstance(default_timezone, pytz.tzinfo.DstTzInfo):
            date = default_timezone.localize(date)
        elif isinstance(default_timezone, datetime.tzinfo):
            date = date.replace(tzinfo=default_timezone)
        else:
            raise Exception("It's not supposed to happen!")
    elif naive:
        return date.replace(tzinfo=None)
    if to_timezone:
        if to_timezone == 'local':
            to_timezone = dateutil.tz.tzlocal()
        elif isinstance(to_timezone, str):
            to_timezone = pytz.timezone(to_timezone)
        return date.astimezone(to_timezone)
    return date


def parse_date(value, to_timezone=None, default_timezone=None, naive=True):
    """
    Convert LDAP date string to datetime.date object

    :param value: The LDAP date string to convert
    :param to_timezone: If specified, the return datetime will be converted to this specific timezone
                        ('local' for the server local timezone)
                        (optional, default : timezone of the LDAP date string)
    :param default_timezone: The timezone used if LDAP date string does not specify the timezone
                             ('local' for the server local timezone)
                             (optional, default : UTC)
    :param naive: Use naive datetime : do not handle timezone conversion from LDAP
    """
    return parse_datetime(value, to_timezone, default_timezone, naive).date()


def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
    """
    Convert datetime.datetime object to LDAP date string

    :param value: The datetime.datetime object to convert
    :param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo)
                          (optional, default : server local timezone)
    :param to_timezone: The timezone used in LDAP (optional, default : UTC)
    :param naive: Use naive datetime : datetime store as UTC in LDAP (without conversion)
    """
    assert isinstance(value, datetime.datetime), 'First parameter must be an datetime.datetime object (not %s)' % type(value)
    assert from_timezone is None or isinstance(from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), 'from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not %s)' % type(from_timezone)
    assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), 'to_timezone must be None, a datetime.tzinfo object or a string (not %s)' % type(to_timezone)
    if not value.tzinfo and not naive:
        if not from_timezone or from_timezone == 'local':
            from_timezone = dateutil.tz.tzlocal()
        elif isinstance(from_timezone, str):
            from_timezone = pytz.timezone(from_timezone)
        # pytz DST-aware timezones must be applied with localize(), not replace()
        if isinstance(from_timezone, pytz.tzinfo.DstTzInfo):
            from_value = from_timezone.localize(value)
        elif isinstance(from_timezone, datetime.tzinfo):
            from_value = value.replace(tzinfo=from_timezone)
        else:
            raise Exception("It's not supposed to happen!")
    elif naive:
        from_value = value.replace(tzinfo=pytz.utc)
    else:
        from_value = copy.deepcopy(value)
    if not to_timezone:
        to_timezone = pytz.utc
    elif to_timezone == 'local':
        to_timezone = dateutil.tz.tzlocal()
    elif isinstance(to_timezone, str):
        to_timezone = pytz.timezone(to_timezone)
    to_value = from_value.astimezone(to_timezone) if not naive else from_value
    datestring = to_value.strftime('%Y%m%d%H%M%S%z')
    # LDAP GeneralizedTime conventionally uses "Z" for UTC
    if datestring.endswith('+0000'):
        datestring = datestring.replace('+0000', 'Z')
    return datestring


def format_date(value, from_timezone=None, to_timezone=None, naive=True):
    """
    Convert datetime.date object to LDAP date string

    :param value: The datetime.date object to convert
    :param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo)
                          (optional, default : server local timezone)
    :param to_timezone: The timezone used in LDAP (optional, default : UTC)
    :param naive: Use naive datetime : do not handle timezone conversion before formatting
                  and return datetime as UTC (because LDAP requires a timezone)
    """
    assert isinstance(value, datetime.date), 'First parameter must be an datetime.date object (not %s)' % type(value)
    return format_datetime(datetime.datetime.combine(value, datetime.datetime.min.time()), from_timezone, to_timezone, naive)