1118 lines
42 KiB
Python
1118 lines
42 KiB
Python
""" LDAP server connection helper """
|
|
|
|
import copy
|
|
import datetime
|
|
import logging
|
|
|
|
import dateutil.parser
|
|
import dateutil.tz
|
|
import ldap
|
|
import pytz
|
|
from ldap import modlist
|
|
from ldap.controls import SimplePagedResultsControl
|
|
from ldap.controls.simple import RelaxRulesControl
|
|
from ldap.dn import escape_dn_chars, explode_dn
|
|
|
|
from mylib import pretty_format_dict
|
|
|
|
# Logger of this module (LdapClient logs through it)
log = logging.getLogger(name=__name__)

# Charset used by LdapClient to encode/decode attribute values when no
# "encoding" option is configured
DEFAULT_ENCODING: str = "utf-8"
|
|
|
|
|
|
def decode_ldap_value(value, encoding="utf-8"):
    """
    Decoding LDAP attribute values helper

    Recursively decode bytes values (possibly nested in lists or in dict values)
    to strings. Values of any other type are returned unchanged.

    :param value: The value to decode (bytes, str, list, dict, ...)
    :param encoding: The encoding to use (optional, default: utf-8). It is also
                     used for nested values.
    :return: The decoded value
    """
    if isinstance(value, bytes):
        return value.decode(encoding)
    if isinstance(value, list):
        # Propagate the requested encoding to nested values
        return [decode_ldap_value(v, encoding) for v in value]
    if isinstance(value, dict):
        return {key: decode_ldap_value(values, encoding) for key, values in value.items()}
    return value
|
|
|
|
|
|
def encode_ldap_value(value, encoding="utf-8"):
    """
    Encoding LDAP attribute values helper

    Recursively encode str values (possibly nested in lists or in dict values)
    to bytes. Values of any other type are returned unchanged.

    :param value: The value to encode (str, bytes, list, dict, ...)
    :param encoding: The encoding to use (optional, default: utf-8). It is also
                     used for nested values.
    :return: The encoded value
    """
    if isinstance(value, str):
        return value.encode(encoding)
    if isinstance(value, list):
        # Propagate the requested encoding to nested values
        return [encode_ldap_value(v, encoding) for v in value]
    if isinstance(value, dict):
        return {key: encode_ldap_value(values, encoding) for key, values in value.items()}
    return value
|
|
|
|
|
|
class LdapServer:
    """
    LDAP server connection helper

    Wrap a python-ldap connection object. The connection is opened and bound
    lazily (see connect()), and helpers are provided to search, add, update,
    rename and delete objects. Most methods catch LDAP errors: they log them
    (or raise LdapServerException if raiseOnError is enabled) and return False.
    search() lets python-ldap exceptions propagate.
    """

    uri = None
    dn = None
    pwd = None
    v2 = None

    # python-ldap connection object once connected, 0 until then
    con = 0

    def __init__(
        self,
        uri,
        dn=None,
        pwd=None,
        v2=None,
        raiseOnError=False,
        logger=False,
        checkCert=True,
        disableReferral=False,
    ):
        """
        :param uri: The LDAP server URI
        :param dn: The bind DN (optional, default: no simple bind)
        :param pwd: The bind password (optional)
        :param v2: If True, use LDAP protocol version 2 instead of version 3
        :param raiseOnError: If True, raise LdapServerException on errors instead of
                             only logging them (optional, default: False)
        :param logger: The logging.Logger object to use (optional, default: the
                       logger of this module)
        :param checkCert: If False, disable TLS certificate checking
                          (optional, default: True)
        :param disableReferral: If True, disable referrals following
                                (optional, default: False)
        """
        self.uri = uri
        self.dn = dn
        self.pwd = pwd
        self.raiseOnError = raiseOnError
        self.checkCert = checkCert
        self.disableReferral = disableReferral
        if v2:
            self.v2 = True
        if logger:
            self.logger = logger
        else:
            self.logger = logging.getLogger(__name__)

    def _error(self, error, level=logging.WARNING):
        """
        Handle an error message

        :param error: The error message
        :param level: The log level used when raiseOnError is disabled
        :raise LdapServerException: If raiseOnError is enabled
        """
        if self.raiseOnError:
            raise LdapServerException(error)
        self.logger.log(level, error)

    def connect(self):
        """
        Start connection to LDAP server (if not already connected)

        Once initialized, the connection is bound with a simple bind if a bind DN
        is configured, or with a SASL EXTERNAL bind for ldapi:// URIs. Otherwise
        no bind is done.

        :return: True if connected, False on error (when raiseOnError is disabled)
        :raise LdapServerException: On error, if raiseOnError is enabled
        """
        if self.con != 0:
            return True
        try:
            # Note: ldap.set_option() sets python-ldap global options
            if not self.checkCert:
                # pylint: disable=no-member
                ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)
            if self.disableReferral:
                # pylint: disable=no-member
                ldap.set_option(ldap.OPT_REFERRALS, ldap.OPT_OFF)
            con = ldap.initialize(self.uri)
            if self.v2:
                con.protocol_version = ldap.VERSION2  # pylint: disable=no-member
            else:
                con.protocol_version = ldap.VERSION3  # pylint: disable=no-member

            if self.dn:
                con.simple_bind_s(self.dn, self.pwd)
            elif self.uri.startswith("ldapi://"):
                con.sasl_interactive_bind_s("", ldap.sasl.external())

            self.con = con
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(
                f"LdapServer - Error connecting and binding to LDAP server: {e}",
                logging.CRITICAL,
            )
        return False

    def _ensure_connected(self):
        """
        Make sure the connection to the LDAP server is established

        This is an explicit check rather than an assert statement. An assert
        would skip the connection attempt when Python runs with -O.

        :raise AssertionError: If the connection could not be established (same
                               exception type as the former assert statement)
        """
        if not (self.con or self.connect()):
            raise AssertionError("LdapServer - Not connected to LDAP server")

    @staticmethod
    def get_scope(scope):
        """
        Map scope parameter to python-ldap value

        :param scope: The scope name: "base", "one" or "sub"
        :raise LdapServerException: If the scope is unknown
        """
        if scope == "base":
            return ldap.SCOPE_BASE  # pylint: disable=no-member
        if scope == "one":
            return ldap.SCOPE_ONELEVEL  # pylint: disable=no-member
        if scope == "sub":
            return ldap.SCOPE_SUBTREE  # pylint: disable=no-member
        raise LdapServerException(f'Unknown LDAP scope "{scope}"')

    def search(self, basedn, filterstr=None, attrs=None, sizelimit=None, scope=None):
        """
        Run a search on LDAP server

        :param basedn: The base DN of the search
        :param filterstr: The LDAP filter (optional, default: "(objectClass=*)")
        :param attrs: The list of attribute names to retrieve (optional, default: all)
        :param sizelimit: The maximum number of entries to return
                          (optional, default: no limit)
        :param scope: The search scope: "base", "one" or "sub" (optional, default: "sub")
        :return: A dict of entries' attributes (as returned by python-ldap) indexed by DN
        """
        self._ensure_connected()
        res_id = self.con.search(
            basedn,
            self.get_scope(scope if scope else "sub"),
            filterstr if filterstr else "(objectClass=*)",
            attrs if attrs else [],
        )
        ret = {}
        while True:
            res_type, res_data = self.con.result(res_id, 0)
            # The final search result message comes with empty data
            if not res_data:
                break
            # Only keep entries (e.g. search references are ignored)
            if res_type == ldap.RES_SEARCH_ENTRY:  # pylint: disable=no-member
                ret[res_data[0][0]] = res_data[0][1]
                # Stop as soon as sizelimit entries have been retrieved
                if sizelimit and len(ret) >= sizelimit:
                    break
        return ret

    def get_object(self, dn, filterstr=None, attrs=None):
        """
        Retrieve a LDAP object specified by its DN

        :param dn: The object DN
        :param filterstr: An optional LDAP filter the object must match
        :param attrs: The list of attribute names to retrieve (optional, default: all)
        :return: The object's attributes (as returned by python-ldap) or None if not found
        """
        return self.search(dn, filterstr=filterstr, scope="base", attrs=attrs).get(dn)

    def paged_search(
        self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None, sizelimit=None
    ):
        """
        Run a paged search on LDAP server (RFC2696 simple paged results control)

        :param basedn: The base DN of the search
        :param filterstr: The LDAP filter (optional, default: "(objectClass=*)")
        :param attrs: The list of attribute names to retrieve (optional, default: all)
        :param scope: The search scope: "base", "one" or "sub" (optional, default: "sub")
        :param pagesize: The page size (optional, default: 500)
        :param sizelimit: The maximum number of objects to return
                          (optional, default: no limit)
        :return: A dict of objects' attributes indexed by DN, or False on error (when
                 raiseOnError is disabled)
        :raise LdapServerException: On error, if raiseOnError is enabled
        """
        assert not self.v2, "Paged search is not available on LDAP version 2"
        self._ensure_connected()

        # Set parameters default values (if not defined)
        filterstr = filterstr if filterstr else "(objectClass=*)"
        attrs = attrs if attrs else []
        scope = scope if scope else "sub"
        pagesize = pagesize if pagesize else 500

        # Initialize SimplePagedResultsControl object
        page_control = SimplePagedResultsControl(
            True, size=pagesize, cookie=""  # Start without cookie
        )
        ret = {}
        pages_count = 0
        self.logger.debug(
            "LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d"
            " and attrs=%s",
            basedn,
            filterstr,
            scope,
            pagesize,
            attrs,
        )
        while True:
            pages_count += 1
            self.logger.debug(
                "LdapServer - Paged search: request page %d with a maximum of %d objects"
                " (current total count: %d)",
                pages_count,
                pagesize,
                len(ret),
            )
            try:
                res_id = self.con.search_ext(
                    basedn, self.get_scope(scope), filterstr, attrs, serverctrls=[page_control]
                )
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error(
                    f"LdapServer - Error running paged search on LDAP server: {e}", logging.CRITICAL
                )
                return False
            try:
                _rtype, rdata, _rmsgid, rctrls = self.con.result3(res_id)
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error(
                    f"LdapServer - Error pulling paged search result from LDAP server: {e}",
                    logging.CRITICAL,
                )
                return False

            # Detect and catch PagedResultsControl answer from rctrls
            result_page_control = None
            for rctrl in rctrls or []:
                if rctrl.controlType == SimplePagedResultsControl.controlType:
                    result_page_control = rctrl
                    break

            # If the PagedResultsControl answer is not detected, paged search
            # can not continue
            if not result_page_control:
                self._error(
                    "LdapServer - Server ignores RFC2696 control, paged search can not works",
                    logging.CRITICAL,
                )
                return False

            # Store results of this page, stopping if sizelimit is reached
            for obj_dn, obj_attrs in rdata:
                ret[obj_dn] = obj_attrs
                if sizelimit and len(ret) >= sizelimit:
                    break

            # If sizelimit reached, stop
            if sizelimit and len(ret) >= sizelimit:
                break

            # If no cookie returned, we are done
            if not result_page_control.cookie:
                break

            # Otherwise, set cookie for the next search
            page_control.cookie = result_page_control.cookie

        self.logger.debug(
            "LdapServer - Paged search end: %d object(s) retrieved in %d page(s) of %d object(s)",
            len(ret),
            pages_count,
            pagesize,
        )
        return ret

    def add_object(self, dn, attrs, encode=False):
        """
        Add an object in LDAP directory

        :param dn: The object DN
        :param attrs: The object attributes (as dict)
        :param encode: If True, encode str attribute values to bytes first
        :return: True on success, False on error (when raiseOnError is disabled)
        """
        ldif = modlist.addModlist(encode_ldap_value(attrs) if encode else attrs)
        self._ensure_connected()
        try:
            self.logger.debug("LdapServer - Add %s", dn)
            self.con.add_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(f"LdapServer - Error adding {dn}: {e}", logging.ERROR)

        return False

    def update_object(self, dn, old, new, ignore_attrs=None, relax=False, encode=False):
        """
        Update an object in LDAP directory

        :param dn: The object DN
        :param old: The object's old attributes values (as dict)
        :param new: The object's new attributes values (as dict)
        :param ignore_attrs: Optional list of attribute names to ignore
        :param relax: If True, use the relax rules server control (LDAPv3 only)
        :param encode: If True, encode str attribute values to bytes first
        :return: True on success (or if there is nothing to change), False on error
                 (when raiseOnError is disabled)
        """
        assert not relax or not self.v2, "Relax modification is not available on LDAP version 2"
        ldif = self.get_changes(old, new, ignore_attrs=ignore_attrs, encode=encode)
        if not ldif:
            return True
        self._ensure_connected()
        try:
            if relax:
                self.con.modify_ext_s(dn, ldif, serverctrls=[RelaxRulesControl()])
            else:
                self.con.modify_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(
                f"LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}", logging.ERROR
            )
        return False

    @classmethod
    def update_need(cls, old, new, ignore_attrs=None, encode=False):
        """Check if an update is need on a LDAP object based on its old and new attributes values"""
        return bool(cls.get_changes(old, new, ignore_attrs=ignore_attrs, encode=encode))

    @staticmethod
    def get_changes(old, new, ignore_attrs=None, encode=False):
        """Retrieve changes (as modlist) on an object based on its old and new attributes values"""
        return modlist.modifyModlist(
            encode_ldap_value(old) if encode else old,
            encode_ldap_value(new) if encode else new,
            ignore_attr_types=ignore_attrs if ignore_attrs else [],
        )

    @classmethod
    def format_changes(cls, old, new, ignore_attrs=None, prefix=None, encode=False):
        """
        Format changes (modlist) on an object based on its old and new attributes values to
        display/log it
        """
        return cls.format_modify_modlist(
            cls.get_changes(old, new, ignore_attrs=ignore_attrs, encode=encode),
            prefix=prefix,
        )

    @staticmethod
    def format_modify_modlist(ldif, prefix=None):
        """
        Format modify modlist to display/log it

        :param ldif: The modify modlist (list of (operation, attribute, values) tuples)
        :param prefix: Optional prefix string for each line
        :return: One line per modification, joined with newlines
        """
        msg = []
        prefix = prefix if prefix else ""
        for op, attr, val in ldif:
            if op == ldap.MOD_ADD:  # pylint: disable=no-member
                op = "ADD"
            elif op == ldap.MOD_DELETE:  # pylint: disable=no-member
                op = "DELETE"
            elif op == ldap.MOD_REPLACE:  # pylint: disable=no-member
                op = "REPLACE"
            else:
                op = f"UNKNOWN (={op})"
            if val is None and op == "DELETE":
                msg.append(f"{prefix} - {op} {attr}")
            else:
                msg.append(f"{prefix} - {op} {attr}: {val}")
        return "\n".join(msg)

    def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True):
        """
        Rename an object in LDAP directory

        :param dn: The object DN
        :param new_rdn: The new RDN, or a complete new DN (then split in new RDN and
                        new superior DN)
        :param new_sup: The new superior DN (optional, default: unchanged). Must not be
                        provided if new_rdn is a complete DN.
        :param delete_old: If True, delete the old RDN value (optional, default: True)
        :return: True on success, False on error (when raiseOnError is disabled)
        """
        new_dn_parts = explode_dn(new_rdn)
        # If new_rdn is a complete DN, split new RDN and new superior DN
        if len(new_dn_parts) > 1:
            self.logger.debug(
                "LdapServer - Rename with a full new DN detected (%s): split new RDN and new"
                " superior DN",
                new_rdn,
            )
            assert (
                new_sup is None
            ), "You can't provide a complete DN as new_rdn and also provide new_sup parameter"
            new_sup = ",".join(new_dn_parts[1:])
            new_rdn = new_dn_parts[0]
        self._ensure_connected()
        try:
            self.logger.debug(
                "LdapServer - Rename %s in %s (new superior: %s, delete old: %s)",
                dn,
                new_rdn,
                "same" if new_sup is None else new_sup,
                delete_old,
            )
            self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(
                f"LdapServer - Error renaming {dn} in {new_rdn} "
                f'(new superior: {"same" if new_sup is None else new_sup}, '
                f"delete old: {delete_old}): {e}",
                logging.ERROR,
            )

        return False

    def drop_object(self, dn):
        """
        Drop an object in LDAP directory

        :param dn: The object DN
        :return: True on success, False on error (when raiseOnError is disabled)
        """
        self._ensure_connected()
        try:
            self.logger.debug("LdapServer - Delete %s", dn)
            self.con.delete_s(dn)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(f"LdapServer - Error deleting {dn}: {e}", logging.ERROR)

        return False

    @staticmethod
    def get_dn(obj):
        """Retrieve an on object DN from its entry in LDAP search result"""
        return obj[0][0]

    @staticmethod
    def get_attr(obj, attr, all_values=None, default=None, decode=False):
        """
        Retrieve an on object attribute value(s) from the object entry in LDAP search result

        The attribute name lookup falls back to a case-insensitive match.

        :param obj: The object's attributes dict
        :param attr: The attribute name
        :param all_values: If True, return all values instead of the first one only
        :param default: The value returned if the attribute is missing (with
                        all_values, an empty list is returned if default is falsy)
        :param decode: If True, decode bytes value(s) to str
        """
        if attr not in obj:
            for k in obj:
                if k.lower() == attr.lower():
                    attr = k
                    break
        if all_values:
            if attr in obj:
                return decode_ldap_value(obj[attr]) if decode else obj[attr]
            return default or []
        if attr in obj:
            return decode_ldap_value(obj[attr][0]) if decode else obj[attr][0]
        return default
|
|
|
|
|
|
class LdapException(Exception):
    """
    Generic LDAP exception

    Base class of the exceptions of this module. It derives from Exception rather
    than BaseException, so generic ``except Exception`` handlers catch it like any
    other application error.
    """
|
|
|
|
|
|
class LdapServerException(LdapException):
    """
    Exception raised by LdapServer

    Raised for LDAP errors when the server helper is created with
    raiseOnError enabled, and for invalid arguments such as an unknown
    search scope.
    """
|
|
|
|
|
|
class LdapClientException(LdapException):
    """
    Exception raised by LdapClient

    For instance, raised when a lookup expected to match a single LDAP
    object matches several of them.
    """
|
|
|
|
|
|
class LdapClient:
    """
    LDAP Client (based on python-mylib.LdapServer)

    Options are read from an options object (attributes prefixed with the options
    prefix) or from a mylib.Config section. Retrieved objects are decoded and
    cached per object type name. Write operations honour the just-try mode.
    """

    _options = {}
    _config = None
    _config_section = None

    # Connection
    _conn = None

    # Cache objects
    _cached_objects = None

    def __init__(
        self, options=None, options_prefix=None, config=None, config_section=None, initialize=False
    ):
        """
        :param options: Optional options object (options are read from its attributes)
        :param options_prefix: Prefix of the options attributes names (default: "ldap_")
        :param config: Optional mylib.Config object
        :param config_section: The configuration section name (default: "ldap")
        :param initialize: If True, initialize the LDAP connection immediately
        """
        self._options = options if options else {}
        self._options_prefix = options_prefix if options_prefix else "ldap_"
        self._config = config if config else None
        self._config_section = config_section if config_section else "ldap"
        self._cached_objects = {}
        if initialize:
            self.initialize()

    def _get_option(self, option, default=None, required=False):
        """
        Retrieve option value

        The options object takes precedence over the configuration section.

        :param option: The option name (without prefix)
        :param default: The value returned if the option is not defined
        :param required: If True, assert the option is defined
        """
        if self._options and hasattr(self._options, self._options_prefix + option):
            return getattr(self._options, self._options_prefix + option)

        if self._config and self._config.defined(self._config_section, option):
            return self._config.get(self._config_section, option)

        assert not required, f"Options {option} not defined"

        return default

    @property
    def _just_try(self):
        """Check if just-try mode is enabled"""
        return self._get_option(
            "just_try", default=(self._config.get_option("just_try") if self._config else False)
        )

    def _ensure_initialized(self):
        """
        Make sure the LDAP connection is initialized

        This is an explicit check rather than an assert statement. An assert
        would skip the initialization when Python runs with -O.

        :raise AssertionError: If the connection could not be initialized (same
                               exception type as the former assert statement)
        """
        if not (self._conn or self.initialize()):
            raise AssertionError("LdapClient - LDAP connection not initialized")

    def configure(self, comment=None, **kwargs):
        """Configure options on registered mylib.Config object"""
        assert self._config, (
            "mylib.Config object not registered. Must be passed to __init__ as config keyword"
            " argument."
        )

        # Load configuration option types only here to avoid global
        # dependency of ldap module with config one.
        # pylint: disable=import-outside-toplevel
        from mylib.config import BooleanOption, PasswordOption, StringOption

        section = self._config.add_section(
            self._config_section,
            comment=comment if comment else "LDAP connection",
            loaded_callback=self.initialize,
            **kwargs,
        )

        section.add_option(
            StringOption, "uri", default="ldap://localhost", comment="LDAP server URI"
        )
        section.add_option(StringOption, "binddn", comment="LDAP Bind DN")
        section.add_option(
            PasswordOption,
            "bindpwd",
            comment='LDAP Bind password (set to "keyring" to use XDG keyring)',
            username_option="binddn",
            keyring_value="keyring",
        )
        section.add_option(
            BooleanOption, "checkcert", default=True, comment="Check LDAP certificate"
        )
        section.add_option(
            BooleanOption, "disablereferral", default=False, comment="Disable referral following"
        )

        return section

    def initialize(self, loaded_config=None):
        """
        Initialize LDAP connection

        Create a LdapServer object (with raiseOnError enabled) from the options,
        reset the objects cache and connect to the LDAP server.

        :param loaded_config: Optional loaded configuration (this method is
                              registered as loaded callback by configure())
        :return: The result of LdapServer.connect()
        :raise LdapServerException: If the connection fails
        """
        if loaded_config:
            # NOTE(review): stored as the public "config" attribute, which nothing in
            # this class reads (options are read from self._config) - confirm intent.
            self.config = loaded_config
        uri = self._get_option("uri", required=True)
        binddn = self._get_option("binddn")
        log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else "anonymous")
        self._conn = LdapServer(
            uri,
            dn=binddn,
            pwd=self._get_option("bindpwd"),
            # An unset option must not silently disable TLS certificate checking
            # (same default as the configuration option and LdapServer)
            checkCert=self._get_option("checkcert", default=True),
            disableReferral=self._get_option("disablereferral", default=False),
            raiseOnError=True,
        )
        # Reset cache
        self._cached_objects = {}
        return self._conn.connect()

    def decode(self, value):
        """
        Decode LDAP attribute value

        Lists are decoded item by item and str values are returned unchanged. The
        "encoding" and "encoding_error_policy" options are used (default: utf-8
        and "ignore").
        """
        if isinstance(value, list):
            return [self.decode(v) for v in value]
        if isinstance(value, str):
            return value
        return value.decode(
            self._get_option("encoding", default=DEFAULT_ENCODING),
            self._get_option("encoding_error_policy", default="ignore"),
        )

    def encode(self, value):
        """
        Encode LDAP attribute value

        Lists are encoded item by item and bytes values are returned unchanged. The
        "encoding" option is used (default: utf-8).
        """
        if isinstance(value, list):
            return [self.encode(v) for v in value]
        if isinstance(value, bytes):
            return value
        return value.encode(self._get_option("encoding", default=DEFAULT_ENCODING))

    def _get_obj(self, dn, attrs):
        """
        Build and return LDAP object as dict

        :param dn: The object DN
        :param attrs: The object attributes as return by python-ldap search
        """
        obj = {"dn": dn}
        for attr in attrs:
            obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all_values=True)]
        return obj

    @staticmethod
    def get_attr(obj, attr, default="", all_values=False):
        """
        Get LDAP object attribute value(s)

        The attribute name lookup falls back to a case-insensitive match.

        :param obj: The LDAP object as returned by get_object()/get_objects
        :param attr: The attribute name
        :param default: The value returned if the attribute has no value
                        (optional, default: "", or [] with all_values)
        :param all_values: If True, all values of the attribute will be
                           returned instead of the first value only
                           (optional, default: False)
        """
        if attr not in obj:
            for k in obj:
                if k.lower() == attr.lower():
                    attr = k
                    break
        vals = obj.get(attr, [])
        if vals:
            return vals if all_values else vals[0]
        return default if default or not all_values else []

    def get_objects(
        self,
        name,
        filterstr,
        basedn,
        attrs,
        key_attr=None,
        warn=True,
        paged_search=False,
        pagesize=None,
        nocache=False,
    ):
        """
        Retrieve objects from LDAP

        :param name: The object type name
        :param filterstr: The LDAP filter to use to search objects on LDAP directory
        :param basedn: The base DN of the search
        :param attrs: The list of attribute names to retrieve
        :param key_attr: The attribute name or 'dn' to use as key in result
                         (optional, default: the objects DN)
        :param warn: If True, a warning message will be logged if no object is found
                     in LDAP directory (otherwise, it will be just a debug message)
                     (optional, default: True)
        :param paged_search: If True, use paged search to list objects from LDAP directory
                             (optional, default: False)
        :param pagesize: When using paged search, the page size
                         (optional, default: see LdapServer.paged_search)
        :param nocache: If True, disable using cache
        :return: A dict of objects (empty if no object is found)
        """
        if name in self._cached_objects and not nocache:
            log.debug("Retrieved %s objects from cache", name)
            objects = self._cached_objects[name]
        else:
            self._ensure_initialized()
            log.debug(
                'Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn
            )
            if paged_search:
                ldap_data = self._conn.paged_search(
                    basedn=basedn, filterstr=filterstr, attrs=attrs, pagesize=pagesize
                )
            else:
                ldap_data = self._conn.search(
                    basedn=basedn,
                    filterstr=filterstr,
                    attrs=attrs,
                )

            if not ldap_data:
                if warn:
                    log.warning("No %s found in LDAP", name)
                else:
                    log.debug("No %s found in LDAP", name)
                return {}

            objects = {}
            for obj_dn, obj_attrs in ldap_data.items():
                # Ignore invalid result (view with an AD)
                if not obj_dn or not isinstance(obj_attrs, dict):
                    continue
                objects[obj_dn] = self._get_obj(obj_dn, obj_attrs)
            if not nocache:
                self._cached_objects[name] = objects
        if not key_attr or key_attr == "dn":
            return objects
        return {self.get_attr(objects[dn], key_attr): objects[dn] for dn in objects}

    def get_object(self, type_name, object_name, filterstr, basedn, attrs, warn=True):
        """
        Retrieve an object from LDAP specified using LDAP search parameters

        Only one object is excepted to be returned by the LDAP search, otherwise, one
        LdapClientException will be raised.

        :param type_name: The object type name
        :param object_name: The object name (only use in log messages)
        :param filterstr: The LDAP filter to use to search the object on LDAP directory
        :param basedn: The base DN of the search
        :param attrs: The list of attribute names to retrieve
        :param warn: If True, a warning message will be logged if no object is found
                     in LDAP directory (otherwise, it will be just a debug message)
                     (optional, default: True)
        :return: The object, or None if not found
        """
        self._ensure_initialized()
        log.debug(
            'Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")',
            type_name,
            object_name,
            filterstr,
            basedn,
        )
        ldap_data = self._conn.search(basedn=basedn, filterstr=filterstr, attrs=attrs)

        if not ldap_data:
            if warn:
                log.warning('No %s "%s" found in LDAP', type_name, object_name)
            else:
                log.debug('No %s "%s" found in LDAP', type_name, object_name)
            return None

        if len(ldap_data) > 1:
            raise LdapClientException(
                f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}'
            )

        dn = next(iter(ldap_data))
        return self._get_obj(dn, ldap_data[dn])

    def get_object_by_dn(self, type_name, dn, populate_cache_method=None, warn=True):
        """
        Retrieve an LDAP object specified by its DN from cache

        :param type_name: The object type name
        :param dn: The object DN
        :param populate_cache_method: The method to use is cache of LDAP object type
                                      is not already populated (optional, default,
                                      False is returned)
        :param warn: If True, a warning message will be logged if object is not found
                     in cache (otherwise, it will be just a debug message)
                     (optional, default: True)
        """
        if type_name not in self._cached_objects:
            if not populate_cache_method:
                return False
            populate_cache_method()
        if type_name not in self._cached_objects:
            if warn:
                log.warning("No %s found in LDAP", type_name)
            else:
                log.debug("No %s found in LDAP", type_name)
            return None
        if dn not in self._cached_objects[type_name]:
            if warn:
                log.warning('No %s found with DN "%s"', type_name, dn)
            else:
                log.debug('No %s found with DN "%s"', type_name, dn)
            return None
        return self._cached_objects[type_name][dn]

    @classmethod
    def object_attr_mached(cls, obj, attr, value, case_sensitive=False):
        """
        Determine if object's attribute matched with specified value

        :param obj: The LDAP object (as returned by get_object/get_objects)
        :param attr: The attribute name
        :param value: The value for the match test
        :param case_sensitive: If True, the match test will be case-sensitive
                               (optional, default: False)
        """
        if case_sensitive:
            return value in cls.get_attr(obj, attr, all_values=True)
        return value.lower() in [v.lower() for v in cls.get_attr(obj, attr, all_values=True)]

    def get_object_by_attr(
        self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True
    ):
        """
        Retrieve an LDAP object specified by one of its attribute

        :param type_name: The object type name
        :param attr: The attribute name
        :param value: The value for the match test
        :param populate_cache_method: The method to use is cache of LDAP object type
                                      is not already populated (optional, default,
                                      False is returned)
        :param case_sensitive: If True, the match test will be case-sensitive
                               (optional, default: False)
        :param warn: If True, a warning message will be logged if object is not found
                     in cache (otherwise, it will be just a debug message)
                     (optional, default: True)
        :raise LdapClientException: If more than one object matches
        """
        if type_name not in self._cached_objects:
            if not populate_cache_method:
                return False
            populate_cache_method()
        if type_name not in self._cached_objects:
            if warn:
                log.warning("No %s found in LDAP", type_name)
            else:
                log.debug("No %s found in LDAP", type_name)
            return None
        matched = {
            dn: obj
            for dn, obj in self._cached_objects[type_name].items()
            if self.object_attr_mached(obj, attr, value, case_sensitive=case_sensitive)
        }
        if not matched:
            if warn:
                log.warning('No %s found with %s="%s"', type_name, attr, value)
            else:
                log.debug('No %s found with %s="%s"', type_name, attr, value)
            return None
        if len(matched) > 1:
            raise LdapClientException(
                f'More than one {type_name} with {attr}="{value}" found: '
                f'{" / ".join(matched.keys())}'
            )

        dn = next(iter(matched))
        return matched[dn]

    def get_changes(self, ldap_obj, attrs, protected_attrs=None):
        """
        Retrieve changes on a LDAP object

        :param ldap_obj: The original LDAP object
        :param attrs: The new LDAP object attributes values
        :param protected_attrs: Optional list of protected attributes ("dn" is always
                                protected)
        :return: A (old, new) tuple of encoded attributes values dicts, or None if
                 there is no change
        """
        old = {}
        new = {}
        protected_attrs = [a.lower() for a in protected_attrs or []]
        protected_attrs.append("dn")
        # New/updated attributes
        for attr, values in attrs.items():
            if protected_attrs and attr.lower() in protected_attrs:
                continue
            if attr in ldap_obj and ldap_obj[attr]:
                if sorted(ldap_obj[attr]) == sorted(values):
                    continue
                old[attr] = self.encode(ldap_obj[attr])
            elif not values:
                continue
            new[attr] = self.encode(values)

        # Deleted attributes
        for attr in ldap_obj:
            if (
                (not protected_attrs or attr.lower() not in protected_attrs)
                and ldap_obj[attr]
                and not attrs.get(attr)
            ):
                old[attr] = self.encode(ldap_obj[attr])
        if old == new:
            return None
        return (old, new)

    def format_changes(self, changes, protected_attrs=None, prefix=None):
        """
        Format changes as string

        :param changes: The changes as returned by get_changes
        :param protected_attrs: Optional list of protected attributes
        :param prefix: Optional prefix string for each line of the returned string
        """
        self._ensure_initialized()
        return self._conn.format_changes(
            changes[0], changes[1], ignore_attrs=protected_attrs, prefix=prefix
        )

    def update_need(self, changes, protected_attrs=None):
        """
        Check if update is need

        :param changes: The changes as returned by get_changes
        :param protected_attrs: Optional list of protected attributes
        """
        if changes is None:
            return False
        self._ensure_initialized()
        return self._conn.update_need(changes[0], changes[1], ignore_attrs=protected_attrs)

    def add_object(self, dn, attrs):
        """
        Add an object

        :param dn: The LDAP object DN
        :param attrs: The LDAP object attributes (as dict, "dn" key is ignored)
        :return: True on success (or in just-try mode), False on error
        """
        attrs = {attr: self.encode(values) for attr, values in attrs.items() if attr != "dn"}
        try:
            if self._just_try:
                log.debug("Just-try mode : do not really add object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.add_object(dn, attrs)
        except LdapServerException:
            log.error(
                "An error occurred adding object %s in LDAP:\n%s\n",
                dn,
                pretty_format_dict(attrs),
                exc_info=True,
            )
        return False

    def update_object(self, ldap_obj, changes, protected_attrs=None, rdn_attr=None, relax=False):
        """
        Update an object

        If the RDN attribute value changes, the object is renamed first (and
        ldap_obj["dn"] is updated when other changes remain to be applied).

        :param ldap_obj: The original LDAP object
        :param changes: The changes to make on LDAP object (as formatted by get_changes() method)
        :param protected_attrs: An optional list of protected attributes
        :param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected)
        :param relax: Enable relax modification server control (optional, default: false)
        :return: True on success (or in just-try mode), False on error
        """
        assert (
            isinstance(changes, (list, tuple))
            and len(changes) == 2
            and isinstance(changes[0], dict)
            and isinstance(changes[1], dict)
        ), f"changes parameter must be a result of get_changes() method ({type(changes)} given)"
        # In case of RDN change, we need to modify passed changes, copy it to make it unchanged in
        # this case
        _changes = copy.deepcopy(changes)
        if not rdn_attr:
            rdn_attr = ldap_obj["dn"].split("=")[0]
            log.debug("Auto-detected RDN attribute from DN: %s => %s", ldap_obj["dn"], rdn_attr)
        old_rdn_values = self.get_attr(_changes[0], rdn_attr, all_values=True)
        new_rdn_values = self.get_attr(_changes[1], rdn_attr, all_values=True)
        if old_rdn_values or new_rdn_values:
            if not new_rdn_values:
                log.error(
                    "%s : Attribute %s can't be deleted because it's used as RDN.",
                    ldap_obj["dn"],
                    rdn_attr,
                )
                return False
            log.debug(
                "%s: Changes detected on %s RDN attribute: must rename object before updating it",
                ldap_obj["dn"],
                rdn_attr,
            )

            # Compute new object DN
            dn_parts = explode_dn(self.decode(ldap_obj["dn"]))
            basedn = ",".join(dn_parts[1:])
            new_rdn = f"{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}"
            new_dn = f"{new_rdn},{basedn}"

            # Rename object
            log.debug("%s: Rename to %s", ldap_obj["dn"], new_dn)
            if not self.move_object(ldap_obj, new_dn):
                return False

            # Remove RDN in changes list
            for attr in list(_changes[0].keys()):
                if attr.lower() == rdn_attr.lower():
                    del _changes[0][attr]
            for attr in list(_changes[1].keys()):
                if attr.lower() == rdn_attr.lower():
                    del _changes[1][attr]

            # Check that there are other changes
            if not _changes[0] and not _changes[1]:
                log.debug("%s: No other change after renaming", new_dn)
                return True

            # Otherwise, update object DN
            ldap_obj["dn"] = new_dn
        else:
            log.debug("%s: No change detected on RDN attribute %s", ldap_obj["dn"], rdn_attr)

        try:
            if self._just_try:
                log.debug("Just-try mode : do not really update object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.update_object(
                ldap_obj["dn"], _changes[0], _changes[1], ignore_attrs=protected_attrs, relax=relax
            )
        except LdapServerException:
            log.error(
                "An error occurred updating object %s in LDAP:\n%s\n -> \n%s\n\n",
                ldap_obj["dn"],
                pretty_format_dict(_changes[0]),
                pretty_format_dict(_changes[1]),
                exc_info=True,
            )
        return False

    def move_object(self, ldap_obj, new_dn_or_rdn):
        """
        Move/rename an object

        :param ldap_obj: The original LDAP object
        :param new_dn_or_rdn: The new LDAP object's DN (or RDN)
        :return: True on success (or in just-try mode), False on error
        """
        try:
            if self._just_try:
                log.debug("Just-try mode : do not really move object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.rename_object(ldap_obj["dn"], new_dn_or_rdn)
        except LdapServerException:
            log.error(
                "An error occurred moving object %s in LDAP (destination: %s)",
                ldap_obj["dn"],
                new_dn_or_rdn,
                exc_info=True,
            )
        return False

    def drop_object(self, ldap_obj):
        """
        Drop/delete an object

        :param ldap_obj: The original LDAP object to delete/drop
        :return: True on success (or in just-try mode), False on error
        """
        try:
            if self._just_try:
                log.debug("Just-try mode : do not really drop object in LDAP")
                return True
            self._ensure_initialized()
            return self._conn.drop_object(ldap_obj["dn"])
        except LdapServerException:
            log.error("An error occurred removing object %s in LDAP", ldap_obj["dn"], exc_info=True)
        return False
|
|
|
|
|
|
#
|
|
# LDAP date string helpers
|
|
#
|
|
|
|
|
|
def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
    """
    Convert LDAP date string to datetime.datetime object

    :param value: The LDAP date string to convert
    :param to_timezone: If specified, the returned datetime will be converted to this
                        specific timezone. Accepts a datetime.tzinfo object, a timezone
                        name or "local" for the server local timezone (optional,
                        default : timezone of the LDAP date string)
    :param default_timezone: The timezone used if the LDAP date string does not specify
                             one. Accepts a datetime.tzinfo object, a timezone name or
                             "local" for the server local timezone (optional,
                             default : UTC)
    :param naive: Use naive datetime : return a naive datetime object without any
                  timezone conversion. A timezone present in the LDAP date string is
                  dropped as-is, and to_timezone/default_timezone are ignored.

    :return: A datetime.datetime object, timezone-aware unless naive is set
    :raises AssertionError: if to_timezone or default_timezone has an unsupported type
                            (only when Python assertions are enabled)
    """
    assert to_timezone is None or isinstance(
        to_timezone, (datetime.tzinfo, str)
    ), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"
    assert default_timezone is None or isinstance(
        default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
    ), (
        "default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
        f" datetime.tzinfo object (not {type(default_timezone)})"
    )
    date = dateutil.parser.parse(value, dayfirst=False)
    if not date.tzinfo:
        if naive:
            return date
        # A date string without timezone is assumed to be UTC unless told otherwise
        if not default_timezone:
            default_timezone = pytz.utc
        elif default_timezone == "local":
            default_timezone = dateutil.tz.tzlocal()
        elif isinstance(default_timezone, str):
            default_timezone = pytz.timezone(default_timezone)
        if isinstance(default_timezone, pytz.tzinfo.DstTzInfo):
            # pytz DST-aware zones must be attached with localize() (see pytz docs),
            # not with replace(tzinfo=...)
            date = default_timezone.localize(date)
        elif isinstance(default_timezone, datetime.tzinfo):
            date = date.replace(tzinfo=default_timezone)
        else:
            # Only reachable if an unsupported type got past the (disabled) assertions
            raise LdapException("It's not supposed to happen!")
    elif naive:
        # Drop the parsed timezone without converting the wall-clock time
        return date.replace(tzinfo=None)
    if to_timezone:
        if to_timezone == "local":
            to_timezone = dateutil.tz.tzlocal()
        elif isinstance(to_timezone, str):
            to_timezone = pytz.timezone(to_timezone)
        return date.astimezone(to_timezone)
    return date
|
|
|
|
|
|
def parse_date(value, to_timezone=None, default_timezone=None, naive=True):
    """
    Convert LDAP date string to datetime.date object

    :param value: The LDAP date string to convert
    :param to_timezone: If specified, the datetime parsed from the LDAP date string is
                        converted to this timezone before its date is extracted
                        (optional, default : timezone of the LDAP date string). Only
                        used when naive is false.
    :param default_timezone: The timezone used if the LDAP date string does not specify
                             one (optional, default : UTC). Only used when naive is
                             false.
    :param naive: Use naive datetime : do not handle any timezone conversion and keep
                  the date as written in the LDAP date string (default : True)

    :return: A datetime.date object
    """
    return parse_datetime(
        value,
        to_timezone=to_timezone,
        default_timezone=default_timezone,
        naive=naive,
    ).date()
|
|
|
|
|
|
def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
    """
    Convert a datetime.datetime object to an LDAP date string

    The result uses the "%Y%m%d%H%M%S%z" layout, with a "+0000" offset written as "Z".

    :param value: The datetime.datetime object to convert
    :param from_timezone: The timezone assumed when value is naive (no tzinfo): a
                          datetime.tzinfo object, a timezone name or "local"
                          (optional, default : server local timezone)
    :param to_timezone: The timezone used in LDAP: a datetime.tzinfo object, a timezone
                        name or "local" (optional, default : UTC)
    :param naive: Use naive datetime : the wall-clock value is written as UTC in LDAP
                  without any conversion (any tzinfo on value is replaced by UTC)
    """
    assert isinstance(
        value, datetime.datetime
    ), f"First parameter must be an datetime.datetime object (not {type(value)})"
    assert from_timezone is None or isinstance(
        from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
    ), (
        "from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
        f" datetime.tzinfo object (not {type(from_timezone)})"
    )
    assert to_timezone is None or isinstance(
        to_timezone, (datetime.tzinfo, str)
    ), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"

    if naive:
        # Stamp the wall-clock time as UTC, no conversion at all
        source = value.replace(tzinfo=pytz.utc)
    elif value.tzinfo:
        # Already timezone-aware: use it as-is
        source = copy.deepcopy(value)
    else:
        # Naive input: attach the source timezone (server local one by default)
        if not from_timezone or from_timezone == "local":
            from_timezone = dateutil.tz.tzlocal()
        elif isinstance(from_timezone, str):
            from_timezone = pytz.timezone(from_timezone)

        if isinstance(from_timezone, pytz.tzinfo.DstTzInfo):
            source = from_timezone.localize(value)
        elif isinstance(from_timezone, datetime.tzinfo):
            source = value.replace(tzinfo=from_timezone)
        else:
            raise LdapException("It's not supposed to happen!")

    if not to_timezone:
        to_timezone = pytz.utc
    elif to_timezone == "local":
        to_timezone = dateutil.tz.tzlocal()
    elif isinstance(to_timezone, str):
        to_timezone = pytz.timezone(to_timezone)

    target = source if naive else source.astimezone(to_timezone)
    formatted = target.strftime("%Y%m%d%H%M%S%z")
    if formatted.endswith("+0000"):
        formatted = formatted.replace("+0000", "Z")
    return formatted
|
|
|
|
|
|
def format_date(value, from_timezone=None, to_timezone=None, naive=True):
    """
    Convert a datetime.date object to an LDAP date string

    The date is expanded to a datetime at midnight (00:00:00) and then handed over to
    format_datetime(). A datetime.datetime value is accepted too (it is a
    datetime.date subclass), but its time of day and tzinfo are discarded.

    :param value: The datetime.date object to convert
    :param from_timezone: The timezone assumed for the (naive) midnight datetime
                          (optional, default : server local timezone)
    :param to_timezone: The timezone used in LDAP (optional, default : UTC)
    :param naive: Use naive datetime : do not handle timezone conversion before
                  formatting and write the result as UTC (LDAP requires a timezone)
                  (default : True)
    """
    assert isinstance(
        value, datetime.date
    ), f"First parameter must be an datetime.date object (not {type(value)})"
    midnight = datetime.datetime.combine(value, datetime.time.min)
    return format_datetime(midnight, from_timezone, to_timezone, naive)
|