Compare commits

..

2 commits

Author SHA1 Message Date
Benjamin Renard
f541630a63 ldap: code cleaning / fix pylint/flake8 warnings
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
ci/woodpecker/tag/woodpecker Pipeline was successful
2022-06-23 18:38:21 +02:00
Benjamin Renard
2bc9964b12 ldap.LdapServer: Add encode/decode helpers and parameters 2022-06-23 18:34:48 +02:00

View file

@ -10,10 +10,10 @@ import pytz
import dateutil.parser import dateutil.parser
import dateutil.tz import dateutil.tz
import ldap import ldap
from ldap import modlist
from ldap.controls import SimplePagedResultsControl from ldap.controls import SimplePagedResultsControl
from ldap.controls.simple import RelaxRulesControl from ldap.controls.simple import RelaxRulesControl
from ldap.dn import escape_dn_chars, explode_dn from ldap.dn import escape_dn_chars, explode_dn
import ldap.modlist as modlist
from mylib import pretty_format_dict from mylib import pretty_format_dict
@ -21,6 +21,34 @@ log = logging.getLogger(__name__)
DEFAULT_ENCODING = 'utf-8' DEFAULT_ENCODING = 'utf-8'
def decode_ldap_value(value, encoding='utf-8'):
""" Decoding LDAP attribute values helper """
if isinstance(value, bytes):
return value.decode_ldap_value(encoding)
if isinstance(value, list):
return [decode_ldap_value(v) for v in value]
if isinstance(value, dict):
return dict(
(key, decode_ldap_value(values))
for key, values in value.items()
)
return value
def encode_ldap_value(value, encoding='utf-8'):
""" Encoding LDAP attribute values helper """
if isinstance(value, str):
return value.encode_ldap_value(encoding)
if isinstance(value, list):
return [encode_ldap_value(v) for v in value]
if isinstance(value, dict):
return dict(
(key, encode_ldap_value(values))
for key, values in value.items()
)
return value
class LdapServer: class LdapServer:
""" LDAP server connection helper """ # pylint: disable=useless-object-inheritance """ LDAP server connection helper """ # pylint: disable=useless-object-inheritance
@ -55,6 +83,7 @@ class LdapServer:
if self.con == 0: if self.con == 0:
try: try:
if not self.checkCert: if not self.checkCert:
# pylint: disable=no-member
ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)
con = ldap.initialize(self.uri) con = ldap.initialize(self.uri)
if self.v2: if self.v2:
@ -70,7 +99,9 @@ class LdapServer:
self.con = con self.con = con
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error('LdapServer - Error connecting and binding to LDAP server : %s' % e, logging.CRITICAL) self._error(
f'LdapServer - Error connecting and binding to LDAP server: {e}',
logging.CRITICAL)
return False return False
return True return True
@ -83,7 +114,7 @@ class LdapServer:
return ldap.SCOPE_ONELEVEL # pylint: disable=no-member return ldap.SCOPE_ONELEVEL # pylint: disable=no-member
if scope == 'sub': if scope == 'sub':
return ldap.SCOPE_SUBTREE # pylint: disable=no-member return ldap.SCOPE_SUBTREE # pylint: disable=no-member
raise Exception("Unknown LDAP scope '%s'" % scope) raise Exception(f'Unknown LDAP scope "{scope}"')
def search(self, basedn, filterstr=None, attrs=None, sizelimit=0, scope=None): def search(self, basedn, filterstr=None, attrs=None, sizelimit=0, scope=None):
""" Run a search on LDAP server """ """ Run a search on LDAP server """
@ -147,12 +178,16 @@ class LdapServer:
serverctrls=[page_control] serverctrls=[page_control]
) )
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error('LdapServer - Error running paged search on LDAP server: %s' % e, logging.CRITICAL) self._error(
f'LdapServer - Error running paged search on LDAP server: {e}',
logging.CRITICAL)
return False return False
try: try:
rtype, rdata, rmsgid, rctrls = self.con.result3(res_id) # pylint: disable=unused-variable rtype, rdata, rmsgid, rctrls = self.con.result3(res_id) # pylint: disable=unused-variable
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error('LdapServer - Error pulling paged search result from LDAP server: %s' % e, logging.CRITICAL) self._error(
f'LdapServer - Error pulling paged search result from LDAP server: {e}',
logging.CRITICAL)
return False return False
# Detect and catch PagedResultsControl answer from rctrls # Detect and catch PagedResultsControl answer from rctrls
@ -165,7 +200,9 @@ class LdapServer:
# If PagedResultsControl answer not detected, paged serach # If PagedResultsControl answer not detected, paged serach
if not result_page_control: if not result_page_control:
self._error('LdapServer - Server ignores RFC2696 control, paged search can not works', logging.CRITICAL) self._error(
'LdapServer - Server ignores RFC2696 control, paged search can not works',
logging.CRITICAL)
return False return False
# Store results of this page # Store results of this page
@ -182,27 +219,28 @@ class LdapServer:
self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize) self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize)
return ret return ret
def add_object(self, dn, attrs): def add_object(self, dn, attrs, encode=False):
""" Add an object in LDAP directory """ """ Add an object in LDAP directory """
ldif = modlist.addModlist(attrs) ldif = modlist.addModlist(encode_ldap_value(attrs) if encode else attrs)
assert self.con or self.connect() assert self.con or self.connect()
try: try:
self.logger.debug("LdapServer - Add %s", dn) self.logger.debug("LdapServer - Add %s", dn)
self.con.add_s(dn, ldif) self.con.add_s(dn, ldif)
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error("LdapServer - Error adding %s : %s" % (dn, e), logging.ERROR) self._error(f'LdapServer - Error adding {dn}: {e}', logging.ERROR)
return False return False
def update_object(self, dn, old, new, ignore_attrs=None, relax=False): def update_object(self, dn, old, new, ignore_attrs=None, relax=False, encode=False):
""" Update an object in LDAP directory """ """ Update an object in LDAP directory """
assert not relax or not self.v2, "Relax modification is not available on LDAP version 2" assert not relax or not self.v2, "Relax modification is not available on LDAP version 2"
ldif = modlist.modifyModlist( ldif = modlist.modifyModlist(
old, new, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else []
) )
if ldif == []: if not ldif:
return True return True
assert self.con or self.connect() assert self.con or self.connect()
try: try:
@ -212,34 +250,42 @@ class LdapServer:
self.con.modify_s(dn, ldif) self.con.modify_s(dn, ldif)
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error("LdapServer - Error updating %s : %s\nOld : %s\nNew : %s" % (dn, e, old, new), logging.ERROR) self._error(
f'LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}',
logging.ERROR)
return False return False
@staticmethod @staticmethod
def update_need(old, new, ignore_attrs=None): def update_need(old, new, ignore_attrs=None, encode=False):
""" Check if an update is need on a LDAP object based on its old and new attributes values """ """ Check if an update is need on a LDAP object based on its old and new attributes values """
ldif = modlist.modifyModlist( ldif = modlist.modifyModlist(
old, new, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else []
) )
if ldif == []: if not ldif:
return False return False
return True return True
@staticmethod @staticmethod
def get_changes(old, new, ignore_attrs=None): def get_changes(old, new, ignore_attrs=None, encode=False):
""" Retrieve changes (as modlist) on an object based on its old and new attributes values """ """ Retrieve changes (as modlist) on an object based on its old and new attributes values """
return modlist.modifyModlist( return modlist.modifyModlist(
old, new, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else []
) )
@staticmethod @staticmethod
def format_changes(old, new, ignore_attrs=None, prefix=None): def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False):
""" Format changes (modlist) on an object based on its old and new attributes values to display/log it """ """ Format changes (modlist) on an object based on its old and new attributes values to display/log it """
msg = [] msg = []
prefix = prefix if prefix else '' prefix = prefix if prefix else ''
for (op, attr, val) in modlist.modifyModlist(old, new, ignore_attr_types=ignore_attrs if ignore_attrs else []): for (op, attr, val) in modlist.modifyModlist(
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else []
):
if op == ldap.MOD_ADD: # pylint: disable=no-member if op == ldap.MOD_ADD: # pylint: disable=no-member
op = 'ADD' op = 'ADD'
elif op == ldap.MOD_DELETE: # pylint: disable=no-member elif op == ldap.MOD_DELETE: # pylint: disable=no-member
@ -247,11 +293,11 @@ class LdapServer:
elif op == ldap.MOD_REPLACE: # pylint: disable=no-member elif op == ldap.MOD_REPLACE: # pylint: disable=no-member
op = 'REPLACE' op = 'REPLACE'
else: else:
op = 'UNKNOWN (=%s)' % op op = f'UNKNOWN (={op})'
if val is None and op == 'DELETE': if val is None and op == 'DELETE':
msg.append('%s - %s %s' % (prefix, op, attr)) msg.append(f'{prefix} - {op} {attr}')
else: else:
msg.append('%s - %s %s: %s' % (prefix, op, attr, val)) msg.append(f'{prefix} - {op} {attr}: {val}')
return '\n'.join(msg) return '\n'.join(msg)
def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True): def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True):
@ -279,13 +325,9 @@ class LdapServer:
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error( self._error(
"LdapServer - Error renaming %s in %s (new superior: %s, delete old: %s): %s" % ( f'LdapServer - Error renaming {dn} in {new_rdn} '
dn, f'(new superior: {"same" if new_sup is None else new_sup}, '
new_rdn, f'delete old: {delete_old}): {e}',
"same" if new_sup is None else new_sup,
delete_old,
e
),
logging.ERROR logging.ERROR
) )
@ -299,7 +341,8 @@ class LdapServer:
self.con.delete_s(dn) self.con.delete_s(dn)
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error("LdapServer - Error deleting %s : %s" % (dn, e), logging.ERROR) self._error(
f'LdapServer - Error deleting {dn}: {e}', logging.ERROR)
return False return False
@ -309,19 +352,19 @@ class LdapServer:
return obj[0][0] return obj[0][0]
@staticmethod @staticmethod
def get_attr(obj, attr, all=None, default=None): def get_attr(obj, attr, all_values=None, default=None, decode=False):
""" Retreive an on object attribute value(s) from the object entry in LDAP search result """ """ Retreive an on object attribute value(s) from the object entry in LDAP search result """
if attr not in obj: if attr not in obj:
for k in obj: for k in obj:
if k.lower() == attr.lower(): if k.lower() == attr.lower():
attr = k attr = k
break break
if all is not None: if all_values:
if attr in obj: if attr in obj:
return obj[attr] return decode_ldap_value(obj[attr]) if decode else obj[attr]
return default or [] return default or []
if attr in obj: if attr in obj:
return obj[attr][0] return decode_ldap_value(obj[attr][0]) if decode else obj[attr][0]
return default return default
@ -370,7 +413,7 @@ class LdapClient:
if self._config and self._config.defined(self._config_section, option): if self._config and self._config.defined(self._config_section, option):
return self._config.get(self._config_section, option) return self._config.get(self._config_section, option)
assert not required, "Options %s not defined" % option assert not required, f'Options {option} not defined'
return default return default
@ -456,7 +499,7 @@ class LdapClient:
""" """
obj = dict(dn=dn) obj = dict(dn=dn)
for attr in attrs: for attr in attrs:
obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all=True)] obj[attr] = [self.decode(v) for v in self._conn.get_attr(attrs, attr, all_values=True)]
return obj return obj
@staticmethod @staticmethod
@ -554,7 +597,8 @@ class LdapClient:
return None return None
if len(ldap_data) > 1: if len(ldap_data) > 1:
raise LdapClientException('More than one %s "%s": %s' % (type_name, object_name, ' / '.join(ldap_data.keys()))) raise LdapClientException(
f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}')
dn = next(iter(ldap_data)) dn = next(iter(ldap_data))
return self._get_obj(dn, ldap_data[dn]) return self._get_obj(dn, ldap_data[dn])
@ -643,7 +687,9 @@ class LdapClient:
log.debug('No %s found with %s="%s"', type_name, attr, value) log.debug('No %s found with %s="%s"', type_name, attr, value)
return None return None
if len(matched) > 1: if len(matched) > 1:
raise LdapClientException('More than one %s with %s="%s" found: %s' % (type_name, attr, value, ' / '.join(matched.keys()))) raise LdapClientException(
f'More than one {type_name} with {attr}="{value}" found: '
f'{" / ".join(matched.keys())}')
dn = next(iter(matched)) dn = next(iter(matched))
return matched[dn] return matched[dn]
@ -658,7 +704,7 @@ class LdapClient:
""" """
old = {} old = {}
new = {} new = {}
protected_attrs = [a.lower() for a in protected_attrs or list()] protected_attrs = [a.lower() for a in protected_attrs or []]
protected_attrs.append('dn') protected_attrs.append('dn')
# New/updated attributes # New/updated attributes
for attr in attrs: for attr in attrs:
@ -740,7 +786,7 @@ class LdapClient:
:param protected_attrs: An optional list of protected attributes :param protected_attrs: An optional list of protected attributes
:param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected) :param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected)
""" """
assert isinstance(changes, (list, tuple)) and len(changes) == 2 and isinstance(changes[0], dict) and isinstance(changes[1], dict), "changes parameter must be a result of get_changes() method (%s given)" % type(changes) assert isinstance(changes, (list, tuple)) and len(changes) == 2 and isinstance(changes[0], dict) and isinstance(changes[1], dict), f'changes parameter must be a result of get_changes() method ({type(changes)} given)'
# In case of RDN change, we need to modify passed changes, copy it to make it unchanged in # In case of RDN change, we need to modify passed changes, copy it to make it unchanged in
# this case # this case
_changes = copy.deepcopy(changes) _changes = copy.deepcopy(changes)
@ -764,8 +810,8 @@ class LdapClient:
# Compute new object DN # Compute new object DN
dn_parts = explode_dn(self.decode(ldap_obj['dn'])) dn_parts = explode_dn(self.decode(ldap_obj['dn']))
basedn = ','.join(dn_parts[1:]) basedn = ','.join(dn_parts[1:])
new_rdn = '%s=%s' % (rdn_attr, escape_dn_chars(self.decode(new_rdn_values[0]))) new_rdn = f'{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}'
new_dn = '%s,%s' % (new_rdn, basedn) new_dn = f'{new_rdn},{basedn}'
# Rename object # Rename object
log.debug('%s: Rename to %s', ldap_obj['dn'], new_dn) log.debug('%s: Rename to %s', ldap_obj['dn'], new_dn)
@ -865,8 +911,8 @@ def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
the timezone (optional, default : server local timezone) the timezone (optional, default : server local timezone)
:param naive: Use naive datetime : return naive datetime object (without timezone conversion from LDAP) :param naive: Use naive datetime : return naive datetime object (without timezone conversion from LDAP)
""" """
assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), 'to_timezone must be None, a datetime.tzinfo object or a string (not %s)' % type(to_timezone) assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})'
assert default_timezone is None or isinstance(default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), 'default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not %s)' % type(default_timezone) assert default_timezone is None or isinstance(default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(default_timezone)})'
date = dateutil.parser.parse(value, dayfirst=False) date = dateutil.parser.parse(value, dayfirst=False)
if not date.tzinfo: if not date.tzinfo:
if naive: if naive:
@ -918,9 +964,9 @@ def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
:param to_timezone: The timezone used in LDAP (optional, default : UTC) :param to_timezone: The timezone used in LDAP (optional, default : UTC)
:param naive: Use naive datetime : datetime store as UTC in LDAP (without conversion) :param naive: Use naive datetime : datetime store as UTC in LDAP (without conversion)
""" """
assert isinstance(value, datetime.datetime), 'First parameter must be an datetime.datetime object (not %s)' % type(value) assert isinstance(value, datetime.datetime), f'First parameter must be an datetime.datetime object (not {type(value)})'
assert from_timezone is None or isinstance(from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), 'from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not %s)' % type(from_timezone) assert from_timezone is None or isinstance(from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(from_timezone)})'
assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), 'to_timezone must be None, a datetime.tzinfo object or a string (not %s)' % type(to_timezone) assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})'
if not value.tzinfo and not naive: if not value.tzinfo and not naive:
if not from_timezone or from_timezone == 'local': if not from_timezone or from_timezone == 'local':
from_timezone = dateutil.tz.tzlocal() from_timezone = dateutil.tz.tzlocal()
@ -960,5 +1006,5 @@ def format_date(value, from_timezone=None, to_timezone=None, naive=True):
:param naive: Use naive datetime : do not handle timezone conversion before formating :param naive: Use naive datetime : do not handle timezone conversion before formating
and return datetime as UTC (because LDAP required a timezone) and return datetime as UTC (because LDAP required a timezone)
""" """
assert isinstance(value, datetime.date), 'First parameter must be an datetime.date object (not %s)' % type(value) assert isinstance(value, datetime.date), f'First parameter must be an datetime.date object (not {type(value)})'
return format_datetime(datetime.datetime.combine(value, datetime.datetime.min.time()), from_timezone, to_timezone, naive) return format_datetime(datetime.datetime.combine(value, datetime.datetime.min.time()), from_timezone, to_timezone, naive)