python-mylib/mylib/ldap.py

1118 lines
42 KiB
Python

""" LDAP server connection helper """
import copy
import datetime
import logging
import dateutil.parser
import dateutil.tz
import ldap
import pytz
from ldap import modlist
from ldap.controls import SimplePagedResultsControl
from ldap.controls.simple import RelaxRulesControl
from ldap.dn import escape_dn_chars, explode_dn
from mylib import pretty_format_dict
# Module-level logger, named after this module
log = logging.getLogger(__name__)
# Encoding used by LdapClient.encode()/LdapClient.decode() when no "encoding" option is set
DEFAULT_ENCODING = "utf-8"
def decode_ldap_value(value, encoding="utf-8"):
    """
    Decoding LDAP attribute values helper

    Bytes are decoded, lists and dict values are decoded recursively (dict keys are kept
    untouched) and any other value is returned unchanged.

    :param value: The value to decode (bytes, list, dict or anything else)
    :param encoding: The encoding used to decode bytes, also applied to nested values
                     (optional, default: utf-8)
    :return: The decoded value, with the same structure as the provided one
    """
    if isinstance(value, bytes):
        return value.decode(encoding)
    if isinstance(value, list):
        # Propagate the requested encoding to nested values
        return [decode_ldap_value(v, encoding) for v in value]
    if isinstance(value, dict):
        return {key: decode_ldap_value(values, encoding) for key, values in value.items()}
    return value
def encode_ldap_value(value, encoding="utf-8"):
    """
    Encoding LDAP attribute values helper

    Strings are encoded, lists and dict values are encoded recursively (dict keys are kept
    untouched) and any other value is returned unchanged.

    :param value: The value to encode (str, list, dict or anything else)
    :param encoding: The encoding used to encode strings, also applied to nested values
                     (optional, default: utf-8)
    :return: The encoded value, with the same structure as the provided one
    """
    if isinstance(value, str):
        return value.encode(encoding)
    if isinstance(value, list):
        # Propagate the requested encoding to nested values
        return [encode_ldap_value(v, encoding) for v in value]
    if isinstance(value, dict):
        return {key: encode_ldap_value(values, encoding) for key, values in value.items()}
    return value
class LdapServer:
    """
    LDAP server connection helper

    Wrap a python-ldap connection and provide search, add, update, rename and delete helpers.
    Errors reported by python-ldap are logged, or raised as LdapServerException when the
    raiseOnError constructor parameter is enabled.
    """

    uri = None
    dn = None
    pwd = None
    v2 = None
    # 0 is the "not connected yet" marker: connect() replaces it by the python-ldap connection
    con = 0

    def __init__(
        self,
        uri,
        dn=None,
        pwd=None,
        v2=None,
        raiseOnError=False,
        logger=False,
        checkCert=True,
        disableReferral=False,
    ):
        """
        :param uri: The LDAP server URI
        :param dn: The bind DN (optional, see connect() for the behavior without it)
        :param pwd: The bind password (optional)
        :param v2: If true, LDAP protocol version 2 is used instead of version 3
        :param raiseOnError: If True, errors are raised as LdapServerException instead of
                             being only logged (optional, default: False)
        :param logger: The logger to use (optional, default: the logger of this module)
        :param checkCert: If False, the TLS certificate check is disabled
                          (optional, default: True)
        :param disableReferral: If True, referrals following is disabled
                                (optional, default: False)
        """
        self.uri = uri
        self.dn = dn
        self.pwd = pwd
        self.raiseOnError = raiseOnError
        self.checkCert = checkCert
        self.disableReferral = disableReferral
        if v2:
            self.v2 = True
        self.logger = logger if logger else logging.getLogger(__name__)

    def _error(self, error, level=logging.WARNING):
        """Raise the error as LdapServerException if raiseOnError is enabled, log it otherwise"""
        if self.raiseOnError:
            raise LdapServerException(error)
        self.logger.log(level, error)

    def _ensure_connected(self):
        """
        Make sure the connection is established, connecting on first use

        This is an explicit check and not an assert statement: assert statements are removed
        when Python runs with -O, which would skip the call to connect(). AssertionError is
        still the raised exception to stay compatible with the previous behavior.
        """
        if not self.con and not self.connect():
            raise AssertionError("LdapServer - Not connected to LDAP server")

    def connect(self):
        """
        Start connection to LDAP server

        A simple bind is made if a bind DN is defined. Otherwise, a SASL EXTERNAL bind is made
        for ldapi:// URIs and no bind is made for other URIs.

        :return: True if connected (or already connected), False on error (when raiseOnError
                 is not enabled)
        """
        if self.con == 0:
            try:
                # Note: these options are set on the ldap module, not on this connection only
                if not self.checkCert:
                    # pylint: disable=no-member
                    ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)
                if self.disableReferral:
                    # pylint: disable=no-member
                    ldap.set_option(ldap.OPT_REFERRALS, ldap.OPT_OFF)
                con = ldap.initialize(self.uri)
                if self.v2:
                    con.protocol_version = ldap.VERSION2  # pylint: disable=no-member
                else:
                    con.protocol_version = ldap.VERSION3  # pylint: disable=no-member
                if self.dn:
                    con.simple_bind_s(self.dn, self.pwd)
                elif self.uri.startswith("ldapi://"):
                    con.sasl_interactive_bind_s("", ldap.sasl.external())
                self.con = con
                return True
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error(
                    f"LdapServer - Error connecting and binding to LDAP server: {e}",
                    logging.CRITICAL,
                )
                return False
        return True

    @staticmethod
    def get_scope(scope):
        """
        Map scope parameter to python-ldap value

        :param scope: The scope name: "base", "one" or "sub"
        :raises LdapServerException: if the scope name is unknown
        """
        if scope == "base":
            return ldap.SCOPE_BASE  # pylint: disable=no-member
        if scope == "one":
            return ldap.SCOPE_ONELEVEL  # pylint: disable=no-member
        if scope == "sub":
            return ldap.SCOPE_SUBTREE  # pylint: disable=no-member
        raise LdapServerException(f'Unknown LDAP scope "{scope}"')

    def search(self, basedn, filterstr=None, attrs=None, sizelimit=None, scope=None):
        """
        Run a search on LDAP server

        :param basedn: The base DN of the search
        :param filterstr: The LDAP filter (optional, default: "(objectClass=*)")
        :param attrs: The list of attribute names to retrieve (optional, default: empty list)
        :param sizelimit: The maximum number of entries to return (optional, default: no limit)
        :param scope: The scope name as accepted by get_scope() (optional, default: "sub")
        :return: A dict of the found entries attributes, with their DN as key
        """
        self._ensure_connected()
        res_id = self.con.search(
            basedn,
            self.get_scope(scope if scope else "sub"),
            filterstr if filterstr else "(objectClass=*)",
            attrs if attrs else [],
        )
        ret = {}
        c = 0
        while True:
            res_type, res_data = self.con.result(res_id, 0)
            if res_data == []:
                break
            if res_type == ldap.RES_SEARCH_ENTRY:  # pylint: disable=no-member
                ret[res_data[0][0]] = res_data[0][1]
                c += 1
                # Stop as soon as the limit is reached (and not one entry after)
                if sizelimit and c >= sizelimit:
                    break
        return ret

    def get_object(self, dn, filterstr=None, attrs=None):
        """
        Retrieve a LDAP object specified by its DN

        :return: The object attributes, or None if the base search does not return this DN
        """
        result = self.search(dn, filterstr=filterstr, scope="base", attrs=attrs)
        return result.get(dn)

    def paged_search(
        self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None, sizelimit=None
    ):
        """
        Run a paged search on LDAP server

        :param basedn: The base DN of the search
        :param filterstr: The LDAP filter (optional, default: "(objectClass=*)")
        :param attrs: The list of attribute names to retrieve (optional, default: empty list)
        :param scope: The scope name as accepted by get_scope() (optional, default: "sub")
        :param pagesize: The page size (optional, default: 500)
        :param sizelimit: The maximum number of entries to return (optional, default: no limit)
        :return: A dict of the found entries attributes with their DN as key, or False on
                 error (when raiseOnError is not enabled)
        """
        assert not self.v2, "Paged search is not available on LDAP version 2"
        self._ensure_connected()

        # Set parameters default values (if not defined)
        filterstr = filterstr if filterstr else "(objectClass=*)"
        attrs = attrs if attrs else []
        scope = scope if scope else "sub"
        pagesize = pagesize if pagesize else 500

        # Initialize SimplePagedResultsControl object
        page_control = SimplePagedResultsControl(
            True, size=pagesize, cookie=""  # Start without cookie
        )
        ret = {}
        pages_count = 0
        self.logger.debug(
            "LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d"
            " and attrs=%s",
            basedn,
            filterstr,
            scope,
            pagesize,
            attrs,
        )
        while True:
            pages_count += 1
            self.logger.debug(
                "LdapServer - Paged search: request page %d with a maximum of %d objects"
                " (current total count: %d)",
                pages_count,
                pagesize,
                len(ret),
            )
            try:
                res_id = self.con.search_ext(
                    basedn, self.get_scope(scope), filterstr, attrs, serverctrls=[page_control]
                )
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error(
                    f"LdapServer - Error running paged search on LDAP server: {e}", logging.CRITICAL
                )
                return False
            try:
                _, rdata, _, rctrls = self.con.result3(res_id)
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error(
                    f"LdapServer - Error pulling paged search result from LDAP server: {e}",
                    logging.CRITICAL,
                )
                return False

            # Detect and catch PagedResultsControl answer from rctrls
            result_page_control = None
            if rctrls:
                for rctrl in rctrls:
                    if rctrl.controlType == SimplePagedResultsControl.controlType:
                        result_page_control = rctrl
                        break

            # If PagedResultsControl answer not detected, paged search can not work
            if not result_page_control:
                self._error(
                    "LdapServer - Server ignores RFC2696 control, paged search can not works",
                    logging.CRITICAL,
                )
                return False

            # Store results of this page
            for obj_dn, obj_attrs in rdata:
                ret[obj_dn] = obj_attrs
                # If sizelimit reached, stop
                if sizelimit and len(ret) >= sizelimit:
                    break

            # If sizelimit reached, stop
            if sizelimit and len(ret) >= sizelimit:
                break

            # If no cookie returned, we are done
            if not result_page_control.cookie:
                break

            # Otherwise, set cookie for the next search
            page_control.cookie = result_page_control.cookie

        self.logger.debug(
            "LdapServer - Paged search end: %d object(s) retrieved in %d page(s) of %d object(s)",
            len(ret),
            pages_count,
            pagesize,
        )
        return ret

    def add_object(self, dn, attrs, encode=False):
        """
        Add an object in LDAP directory

        :param dn: The DN of the object to add
        :param attrs: The object attributes (as dict)
        :param encode: If True, attributes values are encoded using encode_ldap_value()
                       (optional, default: False)
        :return: True on success, False on error (when raiseOnError is not enabled)
        """
        ldif = modlist.addModlist(encode_ldap_value(attrs) if encode else attrs)
        self._ensure_connected()
        try:
            self.logger.debug("LdapServer - Add %s", dn)
            self.con.add_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(f"LdapServer - Error adding {dn}: {e}", logging.ERROR)
        return False

    def update_object(self, dn, old, new, ignore_attrs=None, relax=False, encode=False):
        """
        Update an object in LDAP directory

        :param dn: The DN of the object to update
        :param old: The old attributes values (as dict)
        :param new: The new attributes values (as dict)
        :param ignore_attrs: The list of attribute names to ignore (optional)
        :param relax: If True, the Relax Rules server control is sent with the modification
                      (optional, default: False, not available on LDAP version 2)
        :param encode: If True, attributes values are encoded using encode_ldap_value()
                       (optional, default: False)
        :return: True on success or if there is nothing to change, False on error (when
                 raiseOnError is not enabled)
        """
        assert not relax or not self.v2, "Relax modification is not available on LDAP version 2"
        ldif = self.get_changes(old, new, ignore_attrs=ignore_attrs, encode=encode)
        if not ldif:
            return True
        self._ensure_connected()
        try:
            if relax:
                self.con.modify_ext_s(dn, ldif, serverctrls=[RelaxRulesControl()])
            else:
                self.con.modify_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(
                f"LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}", logging.ERROR
            )
        return False

    @classmethod
    def update_need(cls, old, new, ignore_attrs=None, encode=False):
        """Check if an update is need on a LDAP object based on its old and new attributes values"""
        return bool(cls.get_changes(old, new, ignore_attrs=ignore_attrs, encode=encode))

    @staticmethod
    def get_changes(old, new, ignore_attrs=None, encode=False):
        """Retrieve changes (as modlist) on an object based on its old and new attributes values"""
        return modlist.modifyModlist(
            encode_ldap_value(old) if encode else old,
            encode_ldap_value(new) if encode else new,
            ignore_attr_types=ignore_attrs if ignore_attrs else [],
        )

    @classmethod
    def format_changes(cls, old, new, ignore_attrs=None, prefix=None, encode=False):
        """
        Format changes (modlist) on an object based on its old and new attributes values to
        display/log it
        """
        return cls.format_modify_modlist(
            cls.get_changes(old, new, ignore_attrs=ignore_attrs, encode=encode),
            prefix=prefix,
        )

    @staticmethod
    def format_modify_modlist(ldif, prefix=None):
        """
        Format modify modlist to display/log it

        :param ldif: The modify modlist: an iterable of (operation, attribute, value) tuples
        :param prefix: Optional prefix string for each line of the returned string
        :return: One line per modification, joined by newlines
        """
        msg = []
        prefix = prefix if prefix else ""
        for op, attr, val in ldif:
            if op == ldap.MOD_ADD:  # pylint: disable=no-member
                op = "ADD"
            elif op == ldap.MOD_DELETE:  # pylint: disable=no-member
                op = "DELETE"
            elif op == ldap.MOD_REPLACE:  # pylint: disable=no-member
                op = "REPLACE"
            else:
                op = f"UNKNOWN (={op})"
            if val is None and op == "DELETE":
                msg.append(f"{prefix} - {op} {attr}")
            else:
                msg.append(f"{prefix} - {op} {attr}: {val}")
        return "\n".join(msg)

    def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True):
        """
        Rename an object in LDAP directory

        :param dn: The current DN of the object
        :param new_rdn: The new RDN of the object, or its complete new DN (in this case, it is
                        split into the new RDN and the new superior DN)
        :param new_sup: The new superior DN (optional, default: unchanged)
        :param delete_old: Passed as delold to python-ldap rename_s() (optional, default: True)
        :return: True on success, False on error (when raiseOnError is not enabled)
        """
        # If new_rdn is a complete DN, split new RDN and new superior DN
        new_dn_parts = explode_dn(new_rdn)
        if len(new_dn_parts) > 1:
            self.logger.debug(
                "LdapServer - Rename with a full new DN detected (%s): split new RDN and new"
                " superior DN",
                new_rdn,
            )
            assert (
                new_sup is None
            ), "You can't provide a complete DN as new_rdn and also provide new_sup parameter"
            new_sup = ",".join(new_dn_parts[1:])
            new_rdn = new_dn_parts[0]
        self._ensure_connected()
        try:
            self.logger.debug(
                "LdapServer - Rename %s in %s (new superior: %s, delete old: %s)",
                dn,
                new_rdn,
                "same" if new_sup is None else new_sup,
                delete_old,
            )
            self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(
                f"LdapServer - Error renaming {dn} in {new_rdn} "
                f'(new superior: {"same" if new_sup is None else new_sup}, '
                f"delete old: {delete_old}): {e}",
                logging.ERROR,
            )
        return False

    def drop_object(self, dn):
        """
        Drop an object in LDAP directory

        :return: True on success, False on error (when raiseOnError is not enabled)
        """
        self._ensure_connected()
        try:
            self.logger.debug("LdapServer - Delete %s", dn)
            self.con.delete_s(dn)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(f"LdapServer - Error deleting {dn}: {e}", logging.ERROR)
        return False

    @staticmethod
    def get_dn(obj):
        """Retrieve an on object DN from its entry in LDAP search result"""
        return obj[0][0]

    @staticmethod
    def get_attr(obj, attr, all_values=None, default=None, decode=False):
        """
        Retrieve an on object attribute value(s) from the object entry in LDAP search result

        The attribute name is first looked up as provided, then case-insensitively.

        :param obj: The object attributes (as dict)
        :param attr: The attribute name
        :param all_values: If true, all values of the attribute are returned instead of the
                           first one only (optional)
        :param default: The value returned if the attribute is not found (or has no value
                        when only the first value is requested). When all values are
                        requested, an empty list is returned if the default is not set.
        :param decode: If True, the returned value(s) are decoded using decode_ldap_value()
                       (optional, default: False)
        """
        if attr not in obj:
            for k in obj:
                if k.lower() == attr.lower():
                    attr = k
                    break
        if all_values:
            if attr in obj:
                return decode_ldap_value(obj[attr]) if decode else obj[attr]
            return default or []
        # An attribute without any value is handled as a missing one instead of raising
        # an IndexError
        if attr in obj and obj[attr]:
            return decode_ldap_value(obj[attr][0]) if decode else obj[attr][0]
        return default
class LdapException(Exception):
    """
    Generic LDAP exception

    Derived from Exception (and not directly from BaseException) so that generic
    "except Exception" handlers also catch it.
    """


class LdapServerException(LdapException):
    """Generic exception raised by LdapServer"""


class LdapClientException(LdapException):
    """Generic exception raised by LdapClient"""
class LdapClient:
    """
    LDAP Client (based on python-mylib.LdapServer)

    Options are read from the options object first (attributes prefixed by the options
    prefix), then from the registered configuration object. Objects retrieved with
    get_objects() are cached by object type name.
    """

    _options = {}
    _config = None
    _config_section = None

    # Connection
    _conn = None

    # Cache objects
    _cached_objects = None

    def __init__(
        self, options=None, options_prefix=None, config=None, config_section=None, initialize=False
    ):
        """
        :param options: The options object, read using getattr() (optional)
        :param options_prefix: The prefix of option names in the options object
                               (optional, default: "ldap_")
        :param config: The configuration object (optional)
        :param config_section: The configuration section name (optional, default: "ldap")
        :param initialize: If True, the LDAP connection is initialized right away
                           (optional, default: False)
        """
        self._options = options if options else {}
        self._options_prefix = options_prefix if options_prefix else "ldap_"
        self._config = config if config else None
        self._config_section = config_section if config_section else "ldap"
        self._cached_objects = {}
        if initialize:
            self.initialize()

    def _get_option(self, option, default=None, required=False):
        """
        Retrieve option value

        :param option: The option name (without prefix)
        :param default: The value returned if the option is not defined
        :param required: If True, an AssertionError is raised if the option is not defined
        """
        if self._options and hasattr(self._options, self._options_prefix + option):
            return getattr(self._options, self._options_prefix + option)
        if self._config and self._config.defined(self._config_section, option):
            return self._config.get(self._config_section, option)
        assert not required, f"Options {option} not defined"
        return default

    @property
    def _just_try(self):
        """Check if just-try mode is enabled"""
        return self._get_option(
            "just_try", default=(self._config.get_option("just_try") if self._config else False)
        )

    def _ensure_initialized(self):
        """
        Make sure the LDAP connection is initialized, initializing it on first use

        This is an explicit check and not an assert statement: assert statements are removed
        when Python runs with -O, which would skip the call to initialize(). AssertionError is
        still the raised exception to stay compatible with the previous behavior.
        """
        if not self._conn and not self.initialize():
            raise AssertionError("LDAP connection is not initialized")

    def configure(self, comment=None, **kwargs):
        """
        Configure options on registered mylib.Config object

        :param comment: The configuration section comment (optional, default: "LDAP connection")
        :param kwargs: Extra keyword arguments passed to the config add_section() method
        :return: The configuration section
        """
        assert self._config, (
            "mylib.Config object not registered. Must be passed to __init__ as config keyword"
            " argument."
        )

        # Load configuration option types only here to avoid global
        # dependency of ldap module with config one.
        # pylint: disable=import-outside-toplevel
        from mylib.config import BooleanOption, PasswordOption, StringOption

        section = self._config.add_section(
            self._config_section,
            comment=comment if comment else "LDAP connection",
            loaded_callback=self.initialize,
            **kwargs,
        )
        section.add_option(
            StringOption, "uri", default="ldap://localhost", comment="LDAP server URI"
        )
        section.add_option(StringOption, "binddn", comment="LDAP Bind DN")
        section.add_option(
            PasswordOption,
            "bindpwd",
            comment='LDAP Bind password (set to "keyring" to use XDG keyring)',
            username_option="binddn",
            keyring_value="keyring",
        )
        section.add_option(
            BooleanOption, "checkcert", default=True, comment="Check LDAP certificate"
        )
        section.add_option(
            BooleanOption, "disablereferral", default=False, comment="Disable referral following"
        )
        return section

    def initialize(self, loaded_config=None):
        """
        Initialize LDAP connection

        :param loaded_config: The loaded configuration object: if provided, it replaces the
                              registered one (optional)
        :return: The result of the LdapServer connect() method
        """
        if loaded_config:
            # Options are read from self._config (see _get_option()): the loaded
            # configuration must be stored there to be taken into account
            self._config = loaded_config
        uri = self._get_option("uri", required=True)
        binddn = self._get_option("binddn")
        log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else "anonymous")
        self._conn = LdapServer(
            uri,
            dn=binddn,
            pwd=self._get_option("bindpwd"),
            checkCert=self._get_option("checkcert"),
            disableReferral=self._get_option("disablereferral"),
            raiseOnError=True,
        )
        # Reset cache
        self._cached_objects = {}
        return self._conn.connect()

    def decode(self, value):
        """
        Decode LDAP attribute value

        Lists are decoded recursively and strings are returned unchanged. The encoding and
        the decoding error policy are read from the "encoding" and "encoding_error_policy"
        options (default: DEFAULT_ENCODING and "ignore").
        """
        if isinstance(value, list):
            return [self.decode(v) for v in value]
        if isinstance(value, str):
            return value
        return value.decode(
            self._get_option("encoding", default=DEFAULT_ENCODING),
            self._get_option("encoding_error_policy", default="ignore"),
        )

    def encode(self, value):
        """
        Encode LDAP attribute value

        Lists are encoded recursively and bytes are returned unchanged. The encoding is read
        from the "encoding" option (default: DEFAULT_ENCODING).
        """
        if isinstance(value, list):
            return [self.encode(v) for v in value]
        if isinstance(value, bytes):
            return value
        return value.encode(self._get_option("encoding", default=DEFAULT_ENCODING))

    def _get_obj(self, dn, attrs):
        """
        Build and return LDAP object as dict
        :param dn: The object DN
        :param attrs: The object attributes as return by python-ldap search
        :return: A dict with the object DN as "dn" key and, for each attribute, the list of
                 its decoded values
        """
        obj = {"dn": dn}
        for attr in attrs:
            obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all_values=True)]
        return obj

    @staticmethod
    def get_attr(obj, attr, default="", all_values=False):
        """
        Get LDAP object attribute value(s)
        :param obj: The LDAP object as returned by get_object()/get_objects
        :param attr: The attribute name (looked up as provided first, then case-insensitively)
        :param default: The value returned if the attribute has no value. When all values are
                        requested, an empty list is returned if the default is not set.
                        (optional, default: empty string)
        :param all_values: If True, all values of the attribute will be
                           returned instead of the first value only
                           (optional, default: False)
        """
        if attr not in obj:
            for k in obj:
                if k.lower() == attr.lower():
                    attr = k
                    break
        vals = obj.get(attr, [])
        if vals:
            return vals if all_values else vals[0]
        return default if default or not all_values else []

    def get_objects(
        self,
        name,
        filterstr,
        basedn,
        attrs,
        key_attr=None,
        warn=True,
        paged_search=False,
        pagesize=None,
    ):
        """
        Retrieve objects from LDAP

        Found objects are cached by object type name: if this type is already in cache, the
        LDAP directory is not requested (whatever the other search parameters are).

        :param name: The object type name
        :param filterstr: The LDAP filter to use to search objects on LDAP directory
        :param basedn: The base DN of the search
        :param attrs: The list of attribute names to retrieve
        :param key_attr: The attribute name or 'dn' to use as key in result
                         (optional, if leave to None, objects DN are used as key)
        :param warn: If True, a warning message will be logged if no object is found
                     in LDAP directory (otherwise, it will be just a debug message)
                     (optional, default: True)
        :param paged_search: If True, use paged search to list objects from LDAP directory
                             (optional, default: False)
        :param pagesize: When using paged search, the page size
                         (optional, default: see LdapServer.paged_search)
        :return: A dict of the found objects (an empty dict if no object is found)
        """
        if name in self._cached_objects:
            log.debug("Retrieved %s objects from cache", name)
        else:
            self._ensure_initialized()
            log.debug(
                'Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn
            )
            if paged_search:
                ldap_data = self._conn.paged_search(
                    basedn=basedn, filterstr=filterstr, attrs=attrs, pagesize=pagesize
                )
            else:
                ldap_data = self._conn.search(
                    basedn=basedn,
                    filterstr=filterstr,
                    attrs=attrs,
                )
            if not ldap_data:
                if warn:
                    log.warning("No %s found in LDAP", name)
                else:
                    log.debug("No %s found in LDAP", name)
                return {}
            objects = {}
            for obj_dn, obj_attrs in ldap_data.items():
                # Ignore invalid result (view with an AD)
                if not obj_dn or not isinstance(obj_attrs, dict):
                    continue
                objects[obj_dn] = self._get_obj(obj_dn, obj_attrs)
            self._cached_objects[name] = objects
        if not key_attr or key_attr == "dn":
            return self._cached_objects[name]
        return {
            self.get_attr(self._cached_objects[name][dn], key_attr): self._cached_objects[name][dn]
            for dn in self._cached_objects[name]
        }

    def get_object(self, type_name, object_name, filterstr, basedn, attrs, warn=True):
        """
        Retrieve an object from LDAP specified using LDAP search parameters
        Only one object is excepted to be returned by the LDAP search, otherwise, one
        LdapClientException will be raised.
        :param type_name: The object type name
        :param object_name: The object name (only use in log messages)
        :param filterstr: The LDAP filter to use to search the object on LDAP directory
        :param basedn: The base DN of the search
        :param attrs: The list of attribute names to retrieve
        :param warn: If True, a warning message will be logged if no object is found
                     in LDAP directory (otherwise, it will be just a debug message)
                     (optional, default: True)
        :return: The LDAP object (as dict), or None if no object is found
        """
        self._ensure_initialized()
        log.debug(
            'Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")',
            type_name,
            object_name,
            filterstr,
            basedn,
        )
        ldap_data = self._conn.search(basedn=basedn, filterstr=filterstr, attrs=attrs)
        if not ldap_data:
            if warn:
                log.warning('No %s "%s" found in LDAP', type_name, object_name)
            else:
                log.debug('No %s "%s" found in LDAP', type_name, object_name)
            return None
        if len(ldap_data) > 1:
            raise LdapClientException(
                f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}'
            )
        dn = next(iter(ldap_data))
        return self._get_obj(dn, ldap_data[dn])

    def get_object_by_dn(self, type_name, dn, populate_cache_method=None, warn=True):
        """
        Retrieve an LDAP object specified by its DN from cache
        :param type_name: The object type name
        :param dn: The object DN
        :param populate_cache_method: The method to use is cache of LDAP object type
                                      is not already populated (optional, default,
                                      False is returned)
        :param warn: If True, a warning message will be logged if object is not found
                     in cache (otherwise, it will be just a debug message)
                     (optional, default: True)
        :return: The LDAP object, None if not found, or False if the cache is not populated
                 and no method is provided to do it
        """
        if type_name not in self._cached_objects:
            if not populate_cache_method:
                return False
            populate_cache_method()
        if type_name not in self._cached_objects:
            if warn:
                log.warning("No %s found in LDAP", type_name)
            else:
                log.debug("No %s found in LDAP", type_name)
            return None
        if dn not in self._cached_objects[type_name]:
            if warn:
                log.warning('No %s found with DN "%s"', type_name, dn)
            else:
                log.debug('No %s found with DN "%s"', type_name, dn)
            return None
        return self._cached_objects[type_name][dn]

    @classmethod
    def object_attr_mached(cls, obj, attr, value, case_sensitive=False):
        """
        Determine if object's attribute matched with specified value
        :param obj: The LDAP object (as returned by get_object/get_objects)
        :param attr: The attribute name
        :param value: The value for the match test
        :param case_sensitive: If True, the match test will be case-sensitive
                               (optional, default: False)
        """
        if case_sensitive:
            return value in cls.get_attr(obj, attr, all_values=True)
        return value.lower() in [v.lower() for v in cls.get_attr(obj, attr, all_values=True)]

    def get_object_by_attr(
        self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True
    ):
        """
        Retrieve an LDAP object specified by one of its attribute
        :param type_name: The object type name
        :param attr: The attribute name
        :param value: The value for the match test
        :param populate_cache_method: The method to use is cache of LDAP object type
                                      is not already populated (optional, default,
                                      False is returned)
        :param case_sensitive: If True, the match test will be case-sensitive
                               (optional, default: False)
        :param warn: If True, a warning message will be logged if object is not found
                     in cache (otherwise, it will be just a debug message)
                     (optional, default: True)
        :return: The LDAP object, None if not found, or False if the cache is not populated
                 and no method is provided to do it
        :raises LdapClientException: if more than one object matched
        """
        if type_name not in self._cached_objects:
            if not populate_cache_method:
                return False
            populate_cache_method()
        if type_name not in self._cached_objects:
            if warn:
                log.warning("No %s found in LDAP", type_name)
            else:
                log.debug("No %s found in LDAP", type_name)
            return None
        matched = {
            dn: obj
            for dn, obj in self._cached_objects[type_name].items()
            if self.object_attr_mached(obj, attr, value, case_sensitive=case_sensitive)
        }
        if not matched:
            if warn:
                log.warning('No %s found with %s="%s"', type_name, attr, value)
            else:
                log.debug('No %s found with %s="%s"', type_name, attr, value)
            return None
        if len(matched) > 1:
            raise LdapClientException(
                f'More than one {type_name} with {attr}="{value}" found: '
                f'{" / ".join(matched.keys())}'
            )
        dn = next(iter(matched))
        return matched[dn]

    def get_changes(self, ldap_obj, attrs, protected_attrs=None):
        """
        Retrieve changes on a LDAP object
        :param ldap_obj: The original LDAP object
        :param attrs: The new LDAP object attributes values
        :param protected_attrs: Optional list of protected attributes (the "dn" key is always
                                protected)
        :return: None if no change is detected, otherwise a tuple of two dicts: the old and
                 the new (encoded) values of the changed attributes
        """
        old = {}
        new = {}
        protected_attrs = [a.lower() for a in protected_attrs or []]
        protected_attrs.append("dn")

        # New/updated attributes
        for attr, values in attrs.items():
            if protected_attrs and attr.lower() in protected_attrs:
                continue
            if attr in ldap_obj and ldap_obj[attr]:
                if sorted(ldap_obj[attr]) == sorted(values):
                    continue
                old[attr] = self.encode(ldap_obj[attr])
            elif not values:
                continue
            new[attr] = self.encode(values)

        # Deleted attributes
        for attr in ldap_obj:
            if (
                (not protected_attrs or attr.lower() not in protected_attrs)
                and ldap_obj[attr]
                and not attrs.get(attr)
            ):
                old[attr] = self.encode(ldap_obj[attr])
        if old == new:
            return None
        return (old, new)

    def format_changes(self, changes, protected_attrs=None, prefix=None):
        """
        Format changes as string
        :param changes: The changes as returned by get_changes
        :param protected_attrs: Optional list of protected attributes
        :param prefix: Optional prefix string for each line of the returned string
        """
        self._ensure_initialized()
        return self._conn.format_changes(
            changes[0], changes[1], ignore_attrs=protected_attrs, prefix=prefix
        )

    def update_need(self, changes, protected_attrs=None):
        """
        Check if update is need
        :param changes: The changes as returned by get_changes
        :param protected_attrs: Optional list of protected attributes
        """
        if changes is None:
            return False
        self._ensure_initialized()
        return self._conn.update_need(changes[0], changes[1], ignore_attrs=protected_attrs)

    def add_object(self, dn, attrs):
        """
        Add an object
        :param dn: The LDAP object DN
        :param attrs: The LDAP object attributes (as dict)
        :return: True on success (or in just-try mode), False on error
        """
        attrs = {attr: self.encode(values) for attr, values in attrs.items() if attr != "dn"}
        try:
            if self._just_try:
                log.debug("Just-try mode : do not really add object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.add_object(dn, attrs)
        except LdapServerException:
            log.error(
                "An error occurred adding object %s in LDAP:\n%s\n",
                dn,
                pretty_format_dict(attrs),
                exc_info=True,
            )
        return False

    def update_object(self, ldap_obj, changes, protected_attrs=None, rdn_attr=None, relax=False):
        """
        Update an object

        If the RDN attribute is changed, the object is renamed first.

        :param ldap_obj: The original LDAP object
        :param changes: The changes to make on LDAP object (as formatted by get_changes() method)
        :param protected_attrs: An optional list of protected attributes
        :param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected)
        :param relax: Enable relax modification server control (optional, default: false)
        :return: True on success (or in just-try mode), False on error
        """
        assert (
            isinstance(changes, (list, tuple))
            and len(changes) == 2
            and isinstance(changes[0], dict)
            and isinstance(changes[1], dict)
        ), f"changes parameter must be a result of get_changes() method ({type(changes)} given)"

        # In case of RDN change, we need to modify passed changes, copy it to make it unchanged in
        # this case
        _changes = copy.deepcopy(changes)
        if not rdn_attr:
            rdn_attr = ldap_obj["dn"].split("=")[0]
            log.debug("Auto-detected RDN attribute from DN: %s => %s", ldap_obj["dn"], rdn_attr)
        old_rdn_values = self.get_attr(_changes[0], rdn_attr, all_values=True)
        new_rdn_values = self.get_attr(_changes[1], rdn_attr, all_values=True)
        if old_rdn_values or new_rdn_values:
            if not new_rdn_values:
                log.error(
                    "%s : Attribute %s can't be deleted because it's used as RDN.",
                    ldap_obj["dn"],
                    rdn_attr,
                )
                return False
            log.debug(
                "%s: Changes detected on %s RDN attribute: must rename object before updating it",
                ldap_obj["dn"],
                rdn_attr,
            )
            # Compute new object DN
            dn_parts = explode_dn(self.decode(ldap_obj["dn"]))
            basedn = ",".join(dn_parts[1:])
            new_rdn = f"{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}"
            new_dn = f"{new_rdn},{basedn}"

            # Rename object
            log.debug("%s: Rename to %s", ldap_obj["dn"], new_dn)
            if not self.move_object(ldap_obj, new_dn):
                return False

            # Remove RDN in changes list
            for attr in list(_changes[0].keys()):
                if attr.lower() == rdn_attr.lower():
                    del _changes[0][attr]
            for attr in list(_changes[1].keys()):
                if attr.lower() == rdn_attr.lower():
                    del _changes[1][attr]

            # Check that there are other changes
            if not _changes[0] and not _changes[1]:
                log.debug("%s: No other change after renaming", new_dn)
                return True

            # Otherwise, update object DN
            ldap_obj["dn"] = new_dn
        else:
            log.debug("%s: No change detected on RDN attribute %s", ldap_obj["dn"], rdn_attr)
        try:
            if self._just_try:
                log.debug("Just-try mode : do not really update object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.update_object(
                ldap_obj["dn"], _changes[0], _changes[1], ignore_attrs=protected_attrs, relax=relax
            )
        except LdapServerException:
            log.error(
                "An error occurred updating object %s in LDAP:\n%s\n -> \n%s\n\n",
                ldap_obj["dn"],
                pretty_format_dict(_changes[0]),
                pretty_format_dict(_changes[1]),
                exc_info=True,
            )
        return False

    def move_object(self, ldap_obj, new_dn_or_rdn):
        """
        Move/rename an object
        :param ldap_obj: The original LDAP object
        :param new_dn_or_rdn: The new LDAP object's DN (or RDN)
        :return: True on success (or in just-try mode), False on error
        """
        try:
            if self._just_try:
                log.debug("Just-try mode : do not really move object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.rename_object(ldap_obj["dn"], new_dn_or_rdn)
        except LdapServerException:
            log.error(
                "An error occurred moving object %s in LDAP (destination: %s)",
                ldap_obj["dn"],
                new_dn_or_rdn,
                exc_info=True,
            )
        return False

    def drop_object(self, ldap_obj):
        """
        Drop/delete an object
        :param ldap_obj: The original LDAP object to delete/drop
        :return: True on success (or in just-try mode), False on error
        """
        try:
            if self._just_try:
                log.debug("Just-try mode : do not really drop object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.drop_object(ldap_obj["dn"])
        except LdapServerException:
            log.error("An error occurred removing object %s in LDAP", ldap_obj["dn"], exc_info=True)
        return False
#
# LDAP date string helpers
#
def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
    """
    Convert LDAP date string to datetime.datetime object

    :param value: The LDAP date string to convert
    :param to_timezone: If specified, the returned datetime is converted to this timezone:
                        a datetime.tzinfo object, a timezone name or "local"
                        (optional, default: timezone of the LDAP date string)
    :param default_timezone: The timezone used if the LDAP date string does not specify one:
                             a tzinfo object, a timezone name or "local"
                             (optional, default: UTC)
    :param naive: Use naive datetime: return a naive datetime object (without timezone
                  conversion from LDAP)
    """
    assert to_timezone is None or isinstance(
        to_timezone, (datetime.tzinfo, str)
    ), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"
    assert default_timezone is None or isinstance(
        default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
    ), (
        "default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
        f" datetime.tzinfo object (not {type(default_timezone)})"
    )
    parsed = dateutil.parser.parse(value, dayfirst=False)

    if parsed.tzinfo:
        if naive:
            # Naive mode: just drop the timezone of the LDAP date string
            return parsed.replace(tzinfo=None)
    else:
        if naive:
            return parsed

        # Resolve the timezone to apply on the parsed date
        if not default_timezone:
            fallback_tz = pytz.utc
        elif default_timezone == "local":
            fallback_tz = dateutil.tz.tzlocal()
        elif isinstance(default_timezone, str):
            fallback_tz = pytz.timezone(default_timezone)
        else:
            fallback_tz = default_timezone

        # pytz DST-aware timezones must be applied using their localize() method
        if isinstance(fallback_tz, pytz.tzinfo.DstTzInfo):
            parsed = fallback_tz.localize(parsed)
        elif isinstance(fallback_tz, datetime.tzinfo):
            parsed = parsed.replace(tzinfo=fallback_tz)
        else:
            raise LdapException("It's not supposed to happen!")

    if not to_timezone:
        return parsed

    # Resolve the timezone of the returned datetime
    if to_timezone == "local":
        target_tz = dateutil.tz.tzlocal()
    elif isinstance(to_timezone, str):
        target_tz = pytz.timezone(to_timezone)
    else:
        target_tz = to_timezone
    return parsed.astimezone(target_tz)
def parse_date(value, to_timezone=None, default_timezone=None, naive=True):
    """
    Convert LDAP date string to datetime.date object

    The string is parsed using parse_datetime() and only the date part is returned.

    :param value: The LDAP date string to convert
    :param to_timezone: Passed to parse_datetime() (optional)
    :param default_timezone: Passed to parse_datetime() (optional)
    :param naive: Use naive datetime: do not handle timezone conversion from LDAP
                  (optional, default: True)
    """
    parsed = parse_datetime(value, to_timezone, default_timezone, naive)
    return parsed.date()
def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
    """
    Convert datetime.datetime object to LDAP date string

    :param value: The datetime.datetime object to convert
    :param from_timezone: The timezone used if the datetime.datetime object is naive (no
                          tzinfo): a tzinfo object, a timezone name or "local"
                          (optional, default: local timezone)
    :param to_timezone: The timezone used in LDAP: a datetime.tzinfo object, a timezone name
                        or "local" (optional, default: UTC)
    :param naive: Use naive datetime: datetime stored as UTC in LDAP (without conversion)
    """
    assert isinstance(
        value, datetime.datetime
    ), f"First parameter must be an datetime.datetime object (not {type(value)})"
    assert from_timezone is None or isinstance(
        from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
    ), (
        "from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
        f" datetime.tzinfo object (not {type(from_timezone)})"
    )
    assert to_timezone is None or isinstance(
        to_timezone, (datetime.tzinfo, str)
    ), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"

    if naive:
        # Naive mode: the value is flagged as UTC, without any conversion
        aware = value.replace(tzinfo=pytz.utc)
    elif value.tzinfo:
        aware = copy.deepcopy(value)
    else:
        # Resolve the timezone to apply on the naive value
        if not from_timezone or from_timezone == "local":
            source_tz = dateutil.tz.tzlocal()
        elif isinstance(from_timezone, str):
            source_tz = pytz.timezone(from_timezone)
        else:
            source_tz = from_timezone

        # pytz DST-aware timezones must be applied using their localize() method
        if isinstance(source_tz, pytz.tzinfo.DstTzInfo):
            aware = source_tz.localize(value)
        elif isinstance(source_tz, datetime.tzinfo):
            aware = value.replace(tzinfo=source_tz)
        else:
            raise LdapException("It's not supposed to happen!")

    # The destination timezone is resolved even if it is not used (naive mode)
    if not to_timezone:
        target_tz = pytz.utc
    elif to_timezone == "local":
        target_tz = dateutil.tz.tzlocal()
    elif isinstance(to_timezone, str):
        target_tz = pytz.timezone(to_timezone)
    else:
        target_tz = to_timezone

    if naive:
        converted = aware
    else:
        converted = aware.astimezone(target_tz)

    datestring = converted.strftime("%Y%m%d%H%M%S%z")
    # A zero offset is written as "Z"
    if datestring.endswith("+0000"):
        datestring = datestring.replace("+0000", "Z")
    return datestring
def format_date(value, from_timezone=None, to_timezone=None, naive=True):
    """
    Convert datetime.date object to LDAP date string

    The date is combined with the minimal time of day, then formatted using
    format_datetime().

    :param value: The datetime.date object to convert
    :param from_timezone: Passed to format_datetime() (optional)
    :param to_timezone: Passed to format_datetime() (optional)
    :param naive: Use naive datetime: do not handle timezone conversion before formatting
                  and return datetime as UTC (because LDAP required a timezone)
                  (optional, default: True)
    """
    assert isinstance(
        value, datetime.date
    ), f"First parameter must be an datetime.date object (not {type(value)})"
    start_of_day = datetime.datetime.min.time()
    as_datetime = datetime.datetime.combine(value, start_of_day)
    return format_datetime(as_datetime, from_timezone, to_timezone, naive)