Compare commits

..

No commits in common. "62c3fadf96f8baf93b503e3e4fc61dcdff25906f" and "69d6a596a859cffe518b16562a73abbd40f258c8" have entirely different histories.

35 changed files with 2193 additions and 2558 deletions

View file

@ -1,39 +0,0 @@
# Pre-commit hooks to run tests and ensure code is cleaned.
# See https://pre-commit.com for more information
repos:
- repo: local
hooks:
- id: pytest
name: pytest
entry: python3 -m pytest tests
language: system
pass_filenames: false
always_run: true
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint --extension-pkg-whitelist=cx_Oracle
language: system
types: [python]
require_serial: true
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
args: ['--max-line-length=100']
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: ['--keep-percent-format', '--py37-plus']
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
- id: black
args: ['--target-version', 'py37', '--line-length', '100']
- repo: https://github.com/PyCQA/isort
rev: 5.11.4
hooks:
- id: isort
args: ['--profile', 'black', '--line-length', '100']

View file

@ -8,8 +8,5 @@ disable=invalid-name,
too-many-nested-blocks,
too-many-instance-attributes,
too-many-lines,
line-too-long,
duplicate-code,
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=100

174
HashMap.py Normal file
View file

@ -0,0 +1,174 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# My hash mapping library
#
# Mapping configuration
# {
# '[dst key 1]': { # Key name in the result
#
# 'order': [int], # Processing order between destinations keys
#
# # Source values
# 'other_key': [key], # Other key of the destination to use as source of values
# 'key' : '[src key]', # Key of source hash to get source values
# 'keys' : ['[sk1]', '[sk2]', ...], # List of source hash's keys to get source values
#
# # Clean / convert values
# 'cleanRegex': '[regex]', # Regex that be use to remove unwanted characters. Ex : [^0-9+]
# 'convert': [function], # Function to use to convert value : Original value will be passed
# # as argument and the value retrieve will replace source value in
# # the result
# # Ex :
# # lambda x: x.strip()
# # lambda x: "myformat : %s" % x
# # Deduplicate / check values
# 'deduplicate': [bool], # If True, sources values will be depluplicated
# 'check': [function], # Function to use to check source value : Source value will be passed
# # as argument and if function return True, the value will be preserved
# # Ex :
# # lambda x: x in my_global_hash
# # Join values
# 'join': '[glue]', # If present, sources values will be join using the "glue"
#
# # Alternative mapping
# 'or': { [map configuration] } # If this mapping case does not retreive any value, try to get value(s)
# # with this other mapping configuration
# },
# '[dst key 2]': {
# [...]
# }
# }
#
# Return format :
# {
# '[dst key 1]': ['v1','v2', ...],
# '[dst key 2]': [ ... ],
# [...]
# }
import logging, re
def clean_value(value):
if isinstance(value, int):
value=str(value)
return value.encode('utf8')
def map(map_keys,src,dst={}):
def get_values(dst_key,src,m):
# Extract sources values
values=[]
if 'other_key' in m:
if m['other_key'] in dst:
values=dst[m['other_key']]
if 'key' in m:
if m['key'] in src and src[m['key']]!='':
values.append(clean_value(src[m['key']]))
if 'keys' in m:
for key in m['keys']:
if key in src and src[key]!='':
values.append(clean_value(src[key]))
# Clean and convert values
if 'cleanRegex' in m and len(values)>0:
new_values=[]
for v in values:
nv=re.sub(m['cleanRegex'],'',v)
if nv!='':
new_values.append(nv)
values=new_values
if 'convert' in m and len(values)>0:
new_values=[]
for v in values:
nv=m['convert'](v)
if nv!='':
new_values.append(nv)
values=new_values
# Deduplicate values
if m.get('deduplicate') and len(values)>1:
new_values=[]
for v in values:
if v not in new_values:
new_values.append(v)
values=new_values
# Check values
if 'check' in m and len(values)>0:
new_values=[]
for v in values:
if m['check'](v):
new_values.append(v)
else:
logging.debug('Invalid value %s for key %s' % (v,dst_key))
if dst_key not in invalid_values:
invalid_values[dst_key]=[]
if v not in invalid_values[dst_key]:
invalid_values[dst_key].append(v)
values=new_values
# Join values
if 'join' in m and len(values)>1:
values=[m['join'].join(values)]
# Manage alternative mapping case
if len(values)==0 and 'or' in m:
values=get_values(dst_key,src,m['or'])
return values
for dst_key in sorted(map_keys.keys(), key=lambda x: map_keys[x]['order']):
values=get_values(dst_key,src,map_keys[dst_key])
if len(values)==0:
if 'required' in map_keys[dst_key] and map_keys[dst_key]['required']:
logging.debug('Destination key %s could not be filled from source but is required' % dst_key)
return False
continue
dst[dst_key]=values
return dst
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
src={
'uid': 'hmartin',
'firstname': 'Martin',
'lastname': 'Martin',
'disp_name': 'Henri Martin',
'line_1': '3 rue de Paris',
'line_2': 'Pour Pierre',
'zip_text': '92 120',
'city_text': 'Montrouge',
'line_city': '92120 Montrouge',
'tel1': '01 00 00 00 00',
'tel2': '09 00 00 00 00',
'mobile': '06 00 00 00 00',
'fax': '01 00 00 00 00',
'email': 'H.MARTIN@GMAIL.COM',
}
map_c={
'uid': {'order': 0, 'key': 'uid','required': True},
'givenName': {'order': 1, 'key': 'firstname'},
'sn': {'order': 2, 'key': 'lastname'},
'cn': {'order': 3, 'key': 'disp_name','required': True, 'or': {'attrs': ['firstname','lastname'],'join': ' '}},
'displayName': {'order': 4, 'other_key': 'displayName'},
'street': {'order': 5, 'join': ' / ', 'keys': ['ligne_1','ligne_2']},
'postalCode': {'order': 6, 'key': 'zip_text', 'cleanRegex': '[^0-9]'},
'l': {'order': 7, 'key': 'city_text'},
'postalAddress': {'order': 8, 'join': '$', 'keys': ['ligne_1','ligne_2','ligne_city']},
'telephoneNumber': {'order': 9, 'keys': ['tel1','tel2'], 'cleanRegex': '[^0-9+]', 'deduplicate': True},
'mobile': {'order': 10,'key': 'mobile'},
'facsimileTelephoneNumber': {'order': 11,'key': 'fax'},
'mail': {'order': 12,'key': 'email', 'convert': lambda x: x.lower().strip()}
}
logging.debug('[TEST] Map src=%s / config= %s' % (src,map_c))
logging.debug('[TEST] Result : %s' % map(map_c,src))

View file

@ -49,41 +49,6 @@ Just run `python setup.py install`
To know how to use these libs, you can take a look on *mylib.scripts* content or in *tests* directory.
## Code Style
[pylint](https://pypi.org/project/pylint/) is used to check for errors and enforces a coding standard, using thoses parameters:
```bash
pylint --extension-pkg-whitelist=cx_Oracle
```
[flake8](https://pypi.org/project/flake8/) is also used to check for errors and enforces a coding standard, using thoses parameters:
```bash
flake8 --max-line-length=100
```
[black](https://pypi.org/project/black/) is used to format the code, using thoses parameters:
```bash
black --target-version py37 --line-length 100
```
[isort](https://pypi.org/project/isort/) is used to format the imports, using those parameter:
```bash
isort --profile black --line-length 100
```
[pyupgrade](https://pypi.org/project/pyupgrade/) is used to automatically upgrade syntax, using those parameters:
```bash
pyupgrade --keep-percent-format --py37-plus
```
**Note:** There is `.pre-commit-config.yaml` to use [pre-commit](https://pre-commit.com/) to automatically run these tools before commits. After cloning the repository, execute `pre-commit install` to install the git hook.
## Copyright
Copyright (c) 2013-2021 Benjamin Renard <brenard@zionetrix.net>

View file

@ -10,7 +10,7 @@ def increment_prefix(prefix):
return f'{prefix if prefix else " "} '
def pretty_format_value(value, encoding="utf8", prefix=None):
def pretty_format_value(value, encoding='utf8', prefix=None):
""" Returned pretty formated value to display """
if isinstance(value, dict):
return pretty_format_dict(value, encoding=encoding, prefix=prefix)
@ -22,10 +22,10 @@ def pretty_format_value(value, encoding="utf8", prefix=None):
return f"'{value}'"
if value is None:
return "None"
return f"{value} ({type(value)})"
return f'{value} ({type(value)})'
def pretty_format_value_in_list(value, encoding="utf8", prefix=None):
def pretty_format_value_in_list(value, encoding='utf8', prefix=None):
"""
Returned pretty formated value to display in list
@ -34,31 +34,42 @@ def pretty_format_value_in_list(value, encoding="utf8", prefix=None):
"""
prefix = prefix if prefix else ""
value = pretty_format_value(value, encoding, prefix)
if "\n" in value:
if '\n' in value:
inc_prefix = increment_prefix(prefix)
value = "\n" + "\n".join([inc_prefix + line for line in value.split("\n")])
value = "\n" + "\n".join([
inc_prefix + line
for line in value.split('\n')
])
return value
def pretty_format_dict(value, encoding="utf8", prefix=None):
def pretty_format_dict(value, encoding='utf8', prefix=None):
""" Returned pretty formated dict to display """
prefix = prefix if prefix else ""
result = []
for key in sorted(value.keys()):
result.append(
f"{prefix}- {key} : "
+ pretty_format_value_in_list(value[key], encoding=encoding, prefix=prefix)
f'{prefix}- {key} : '
+ pretty_format_value_in_list(
value[key],
encoding=encoding,
prefix=prefix
)
)
return "\n".join(result)
def pretty_format_list(row, encoding="utf8", prefix=None):
def pretty_format_list(row, encoding='utf8', prefix=None):
""" Returned pretty formated list to display """
prefix = prefix if prefix else ""
result = []
for idx, values in enumerate(row):
result.append(
f"{prefix}- #{idx} : "
+ pretty_format_value_in_list(values, encoding=encoding, prefix=prefix)
f'{prefix}- #{idx} : '
+ pretty_format_value_in_list(
values,
encoding=encoding,
prefix=prefix
)
)
return "\n".join(result)

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-
""" Basic SQL DB client """
import logging
log = logging.getLogger(__name__)
@ -9,7 +12,6 @@ log = logging.getLogger(__name__)
# Exceptions
#
class DBException(Exception):
""" That is the base exception class for all the other exceptions provided by this module. """
@ -27,8 +29,7 @@ class DBNotImplemented(DBException, RuntimeError):
def __init__(self, method, class_name):
super().__init__(
"The method {method} is not yet implemented in class {class_name}",
method=method,
class_name=class_name,
method=method, class_name=class_name
)
@ -38,7 +39,10 @@ class DBFailToConnect(DBException, RuntimeError):
"""
def __init__(self, uri):
super().__init__("An error occured during database connection ({uri})", uri=uri)
super().__init__(
"An error occured during database connection ({uri})",
uri=uri
)
class DBDuplicatedSQLParameter(DBException, KeyError):
@ -49,7 +53,8 @@ class DBDuplicatedSQLParameter(DBException, KeyError):
def __init__(self, parameter_name):
super().__init__(
"Duplicated SQL parameter '{parameter_name}'", parameter_name=parameter_name
"Duplicated SQL parameter '{parameter_name}'",
parameter_name=parameter_name
)
@ -60,7 +65,10 @@ class DBUnsupportedWHEREClauses(DBException, TypeError):
"""
def __init__(self, where_clauses):
super().__init__("Unsupported WHERE clauses: {where_clauses}", where_clauses=where_clauses)
super().__init__(
"Unsupported WHERE clauses: {where_clauses}",
where_clauses=where_clauses
)
class DBInvalidOrderByClause(DBException, TypeError):
@ -71,9 +79,8 @@ class DBInvalidOrderByClause(DBException, TypeError):
def __init__(self, order_by):
super().__init__(
"Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values"
" (ordering field name and direction)",
order_by=order_by,
"Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values (ordering field name and direction)",
order_by=order_by
)
@ -86,11 +93,11 @@ class DB:
self.just_try = just_try
self._conn = None
for arg, value in kwargs.items():
setattr(self, f"_{arg}", value)
setattr(self, f'_{arg}', value)
def connect(self, exit_on_error=True):
""" Connect to DB server """
raise DBNotImplemented("connect", self.__class__.__name__)
raise DBNotImplemented('connect', self.__class__.__name__)
def close(self):
""" Close connection with DB server (if opened) """
@ -103,11 +110,12 @@ class DB:
log.debug(
'Run SQL query "%s" %s',
sql,
"with params = {}".format( # pylint: disable=consider-using-f-string
", ".join([f"{key} = {value}" for key, value in params.items()])
if params
else "without params"
),
"with params = {0}".format( # pylint: disable=consider-using-f-string
', '.join([
f'{key} = {value}'
for key, value in params.items()
]) if params else "without params"
)
)
@staticmethod
@ -115,11 +123,12 @@ class DB:
log.exception(
'Error during SQL query "%s" %s',
sql,
"with params = {}".format( # pylint: disable=consider-using-f-string
", ".join([f"{key} = {value}" for key, value in params.items()])
if params
else "without params"
),
"with params = {0}".format( # pylint: disable=consider-using-f-string
', '.join([
f'{key} = {value}'
for key, value in params.items()
]) if params else "without params"
)
)
def doSQL(self, sql, params=None):
@ -132,7 +141,7 @@ class DB:
:return: True on success, False otherwise
:rtype: bool
"""
raise DBNotImplemented("doSQL", self.__class__.__name__)
raise DBNotImplemented('doSQL', self.__class__.__name__)
def doSelect(self, sql, params=None):
"""
@ -144,7 +153,7 @@ class DB:
:return: List of selected rows as dict on success, False otherwise
:rtype: list, bool
"""
raise DBNotImplemented("doSelect", self.__class__.__name__)
raise DBNotImplemented('doSelect', self.__class__.__name__)
#
# SQL helpers
@ -153,8 +162,10 @@ class DB:
@staticmethod
def _quote_table_name(table):
""" Quote table name """
return '"{}"'.format( # pylint: disable=consider-using-f-string
'"."'.join(table.split("."))
return '"{0}"'.format( # pylint: disable=consider-using-f-string
'"."'.join(
table.split('.')
)
)
@staticmethod
@ -165,7 +176,7 @@ class DB:
@staticmethod
def format_param(param):
""" Format SQL query parameter for prepared query """
return f"%({param})s"
return f'%({param})s'
@classmethod
def _combine_params(cls, params, to_add=None, **kwargs):
@ -190,8 +201,7 @@ class DB:
- a dict of WHERE clauses with field name as key and WHERE clause value as value
- a list of any of previous valid WHERE clauses
:param params: Dict of other already set SQL query parameters (optional)
:param where_op: SQL operator used to combine WHERE clauses together (optional, default:
AND)
:param where_op: SQL operator used to combine WHERE clauses together (optional, default: AND)
:return: A tuple of two elements: raw SQL WHERE combined clauses and parameters on success
:rtype: string, bool
@ -199,27 +209,24 @@ class DB:
if params is None:
params = {}
if where_op is None:
where_op = "AND"
where_op = 'AND'
if isinstance(where_clauses, str):
return (where_clauses, params)
if (
isinstance(where_clauses, tuple)
and len(where_clauses) == 2
and isinstance(where_clauses[1], dict)
):
if isinstance(where_clauses, tuple) and len(where_clauses) == 2 and isinstance(where_clauses[1], dict):
cls._combine_params(params, where_clauses[1])
return (where_clauses[0], params)
if isinstance(where_clauses, (list, tuple)):
sql_where_clauses = []
for where_clause in where_clauses:
sql2, params = cls._format_where_clauses(
where_clause, params=params, where_op=where_op
)
sql2, params = cls._format_where_clauses(where_clause, params=params, where_op=where_op)
sql_where_clauses.append(sql2)
return (f" {where_op} ".join(sql_where_clauses), params)
return (
f' {where_op} '.join(sql_where_clauses),
params
)
if isinstance(where_clauses, dict):
sql_where_clauses = []
@ -228,13 +235,16 @@ class DB:
if field in params:
idx = 1
while param in params:
param = f"{field}_{idx}"
param = f'{field}_{idx}'
idx += 1
cls._combine_params(params, {param: value})
sql_where_clauses.append(
f"{cls._quote_field_name(field)} = {cls.format_param(param)}"
f'{cls._quote_field_name(field)} = {cls.format_param(param)}'
)
return (
f' {where_op} '.join(sql_where_clauses),
params
)
return (f" {where_op} ".join(sql_where_clauses), params)
raise DBUnsupportedWHEREClauses(where_clauses)
@classmethod
@ -245,26 +255,29 @@ class DB:
:param sql: The SQL query to complete
:param params: The dict of parameters of the SQL query to complete
:param where_clauses: The WHERE clause (see _format_where_clauses())
:param where_op: SQL operator used to combine WHERE clauses together (optional, default:
see _format_where_clauses())
:param where_op: SQL operator used to combine WHERE clauses together (optional, default: see _format_where_clauses())
:return:
:rtype: A tuple of two elements: raw SQL WHERE combined clauses and parameters
"""
if where_clauses:
sql_where, params = cls._format_where_clauses(
where_clauses, params=params, where_op=where_op
)
sql_where, params = cls._format_where_clauses(where_clauses, params=params, where_op=where_op)
sql += " WHERE " + sql_where
return (sql, params)
def insert(self, table, values, just_try=False):
""" Run INSERT SQL query """
# pylint: disable=consider-using-f-string
sql = "INSERT INTO {} ({}) VALUES ({})".format(
sql = 'INSERT INTO {0} ({1}) VALUES ({2})'.format(
self._quote_table_name(table),
", ".join([self._quote_field_name(field) for field in values.keys()]),
", ".join([self.format_param(key) for key in values]),
', '.join([
self._quote_field_name(field)
for field in values.keys()
]),
", ".join([
self.format_param(key)
for key in values
])
)
if just_try:
@ -280,18 +293,19 @@ class DB:
def update(self, table, values, where_clauses, where_op=None, just_try=False):
""" Run UPDATE SQL query """
# pylint: disable=consider-using-f-string
sql = "UPDATE {} SET {}".format(
sql = 'UPDATE {0} SET {1}'.format(
self._quote_table_name(table),
", ".join(
[f"{self._quote_field_name(key)} = {self.format_param(key)}" for key in values]
),
", ".join([
f'{self._quote_field_name(key)} = {self.format_param(key)}'
for key in values
])
)
params = values
try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error("Fail to add WHERE clauses", exc_info=True)
log.error('Fail to add WHERE clauses', exc_info=True)
return False
if just_try:
@ -304,15 +318,15 @@ class DB:
return False
return True
def delete(self, table, where_clauses, where_op="AND", just_try=False):
def delete(self, table, where_clauses, where_op='AND', just_try=False):
""" Run DELETE SQL query """
sql = f"DELETE FROM {self._quote_table_name(table)}"
sql = f'DELETE FROM {self._quote_table_name(table)}'
params = {}
try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error("Fail to add WHERE clauses", exc_info=True)
log.error('Fail to add WHERE clauses', exc_info=True)
return False
if just_try:
@ -327,7 +341,7 @@ class DB:
def truncate(self, table, just_try=False):
""" Run TRUNCATE SQL query """
sql = f"TRUNCATE TABLE {self._quote_table_name(table)}"
sql = f'TRUNCATE TABLE {self._quote_table_name(table)}'
if just_try:
log.debug("Just-try mode: execute TRUNCATE query: %s", sql)
@ -339,36 +353,33 @@ class DB:
return False
return True
def select(
self, table, where_clauses=None, fields=None, where_op="AND", order_by=None, just_try=False
):
def select(self, table, where_clauses=None, fields=None, where_op='AND', order_by=None, just_try=False):
""" Run SELECT SQL query """
sql = "SELECT "
if fields is None:
sql += "*"
elif isinstance(fields, str):
sql += f"{self._quote_field_name(fields)}"
sql += f'{self._quote_field_name(fields)}'
else:
sql += ", ".join([self._quote_field_name(field) for field in fields])
sql += ', '.join([self._quote_field_name(field) for field in fields])
sql += f" FROM {self._quote_table_name(table)}"
sql += f' FROM {self._quote_table_name(table)}'
params = {}
try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error("Fail to add WHERE clauses", exc_info=True)
log.error('Fail to add WHERE clauses', exc_info=True)
return False
if order_by:
if isinstance(order_by, str):
sql += f" ORDER BY {order_by}"
sql += f' ORDER BY {order_by}'
elif (
isinstance(order_by, (list, tuple))
and len(order_by) == 2
isinstance(order_by, (list, tuple)) and len(order_by) == 2
and isinstance(order_by[0], str)
and isinstance(order_by[1], str)
and order_by[1].upper() in ("ASC", "UPPER")
and order_by[1].upper() in ('ASC', 'UPPER')
):
sql += f' ORDER BY "{order_by[0]}" {order_by[1].upper()}'
else:

View file

@ -1,51 +1,49 @@
# -*- coding: utf-8 -*-
""" Email client to forge and send emails """
import email.utils
import logging
import os
import smtplib
from email.encoders import encode_base64
from email.mime.base import MIMEBase
from email.mime.multipart import MIMEMultipart
import email.utils
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
from email.encoders import encode_base64
from mako.template import Template as MakoTemplate
from mylib.config import (
BooleanOption,
ConfigurableObject,
IntegerOption,
PasswordOption,
StringOption,
)
from mylib.config import ConfigurableObject
from mylib.config import BooleanOption
from mylib.config import IntegerOption
from mylib.config import PasswordOption
from mylib.config import StringOption
log = logging.getLogger(__name__)
class EmailClient(
ConfigurableObject
): # pylint: disable=useless-object-inheritance,too-many-instance-attributes
class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inheritance,too-many-instance-attributes
"""
Email client
This class abstract all interactions with the SMTP server.
"""
_config_name = "email"
_config_comment = "Email"
_config_name = 'email'
_config_comment = 'Email'
_defaults = {
"smtp_host": "localhost",
"smtp_port": 25,
"smtp_ssl": False,
"smtp_tls": False,
"smtp_user": None,
"smtp_password": None,
"smtp_debug": False,
"sender_name": "No reply",
"sender_email": "noreply@localhost",
"encoding": "utf-8",
"catch_all_addr": None,
"just_try": False,
'smtp_host': 'localhost',
'smtp_port': 25,
'smtp_ssl': False,
'smtp_tls': False,
'smtp_user': None,
'smtp_password': None,
'smtp_debug': False,
'sender_name': 'No reply',
'sender_email': 'noreply@localhost',
'encoding': 'utf-8',
'catch_all_addr': None,
'just_try': False,
}
templates = {}
@ -63,101 +61,55 @@ class EmailClient(
if use_smtp:
section.add_option(
StringOption,
"smtp_host",
default=self._defaults["smtp_host"],
comment="SMTP server hostname/IP address",
)
StringOption, 'smtp_host', default=self._defaults['smtp_host'],
comment='SMTP server hostname/IP address')
section.add_option(
IntegerOption,
"smtp_port",
default=self._defaults["smtp_port"],
comment="SMTP server port",
)
IntegerOption, 'smtp_port', default=self._defaults['smtp_port'],
comment='SMTP server port')
section.add_option(
BooleanOption,
"smtp_ssl",
default=self._defaults["smtp_ssl"],
comment="Use SSL on SMTP server connection",
)
BooleanOption, 'smtp_ssl', default=self._defaults['smtp_ssl'],
comment='Use SSL on SMTP server connection')
section.add_option(
BooleanOption,
"smtp_tls",
default=self._defaults["smtp_tls"],
comment="Use TLS on SMTP server connection",
)
BooleanOption, 'smtp_tls', default=self._defaults['smtp_tls'],
comment='Use TLS on SMTP server connection')
section.add_option(
StringOption,
"smtp_user",
default=self._defaults["smtp_user"],
comment="SMTP authentication username",
)
StringOption, 'smtp_user', default=self._defaults['smtp_user'],
comment='SMTP authentication username')
section.add_option(
PasswordOption,
"smtp_password",
default=self._defaults["smtp_password"],
PasswordOption, 'smtp_password', default=self._defaults['smtp_password'],
comment='SMTP authentication password (set to "keyring" to use XDG keyring)',
username_option="smtp_user",
keyring_value="keyring",
)
username_option='smtp_user', keyring_value='keyring')
section.add_option(
BooleanOption,
"smtp_debug",
default=self._defaults["smtp_debug"],
comment="Enable SMTP debugging",
)
BooleanOption, 'smtp_debug', default=self._defaults['smtp_debug'],
comment='Enable SMTP debugging')
section.add_option(
StringOption,
"sender_name",
default=self._defaults["sender_name"],
comment="Sender name",
)
StringOption, 'sender_name', default=self._defaults['sender_name'],
comment='Sender name')
section.add_option(
StringOption,
"sender_email",
default=self._defaults["sender_email"],
comment="Sender email address",
)
StringOption, 'sender_email', default=self._defaults['sender_email'],
comment='Sender email address')
section.add_option(
StringOption, "encoding", default=self._defaults["encoding"], comment="Email encoding"
)
StringOption, 'encoding', default=self._defaults['encoding'],
comment='Email encoding')
section.add_option(
StringOption,
"catch_all_addr",
default=self._defaults["catch_all_addr"],
comment="Catch all sent emails to this specified email address",
)
StringOption, 'catch_all_addr', default=self._defaults['catch_all_addr'],
comment='Catch all sent emails to this specified email address')
if just_try:
section.add_option(
BooleanOption,
"just_try",
default=self._defaults["just_try"],
comment="Just-try mode: do not really send emails",
)
BooleanOption, 'just_try', default=self._defaults['just_try'],
comment='Just-try mode: do not really send emails')
return section
def forge_message(
self,
rcpt_to,
subject=None,
html_body=None,
text_body=None, # pylint: disable=too-many-arguments,too-many-locals
attachment_files=None,
attachment_payloads=None,
sender_name=None,
sender_email=None,
encoding=None,
template=None,
**template_vars,
):
def forge_message(self, rcpt_to, subject=None, html_body=None, text_body=None, # pylint: disable=too-many-arguments,too-many-locals
attachment_files=None, attachment_payloads=None, sender_name=None,
sender_email=None, encoding=None, template=None, **template_vars):
"""
Forge a message
:param rcpt_to: The recipient of the email. Could be a tuple(name, email) or
just the email of the recipient.
:param rcpt_to: The recipient of the email. Could be a tuple(name, email) or just the email of the recipient.
:param subject: The subject of the email.
:param html_body: The HTML body of the email
:param text_body: The plain text body of the email
@ -170,69 +122,64 @@ class EmailClient(
All other parameters will be consider as template variables.
"""
msg = MIMEMultipart("alternative")
msg["To"] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to
msg["From"] = email.utils.formataddr(
msg = MIMEMultipart('alternative')
msg['To'] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to
msg['From'] = email.utils.formataddr(
(
sender_name or self._get_option("sender_name"),
sender_email or self._get_option("sender_email"),
sender_name or self._get_option('sender_name'),
sender_email or self._get_option('sender_email')
)
)
if subject:
msg["Subject"] = subject.format(**template_vars)
msg["Date"] = email.utils.formatdate(None, True)
encoding = encoding if encoding else self._get_option("encoding")
msg['Subject'] = subject.format(**template_vars)
msg['Date'] = email.utils.formatdate(None, True)
encoding = encoding if encoding else self._get_option('encoding')
if template:
assert template in self.templates, f"Unknwon template {template}"
assert template in self.templates, f'Unknwon template {template}'
# Handle subject from template
if not subject:
assert self.templates[template].get(
"subject"
), f"No subject defined in template {template}"
msg["Subject"] = self.templates[template]["subject"].format(**template_vars)
assert self.templates[template].get('subject'), f'No subject defined in template {template}'
msg['Subject'] = self.templates[template]['subject'].format(**template_vars)
# Put HTML part in last one to prefered it
parts = []
if self.templates[template].get("text"):
if isinstance(self.templates[template]["text"], MakoTemplate):
parts.append(
(self.templates[template]["text"].render(**template_vars), "plain")
)
if self.templates[template].get('text'):
if isinstance(self.templates[template]['text'], MakoTemplate):
parts.append((self.templates[template]['text'].render(**template_vars), 'plain'))
else:
parts.append(
(self.templates[template]["text"].format(**template_vars), "plain")
)
if self.templates[template].get("html"):
if isinstance(self.templates[template]["html"], MakoTemplate):
parts.append((self.templates[template]["html"].render(**template_vars), "html"))
parts.append((self.templates[template]['text'].format(**template_vars), 'plain'))
if self.templates[template].get('html'):
if isinstance(self.templates[template]['html'], MakoTemplate):
parts.append((self.templates[template]['html'].render(**template_vars), 'html'))
else:
parts.append((self.templates[template]["html"].format(**template_vars), "html"))
parts.append((self.templates[template]['html'].format(**template_vars), 'html'))
for body, mime_type in parts:
msg.attach(MIMEText(body.encode(encoding), mime_type, _charset=encoding))
else:
assert subject, "No subject provided"
assert subject, 'No subject provided'
if text_body:
msg.attach(MIMEText(text_body.encode(encoding), "plain", _charset=encoding))
msg.attach(MIMEText(text_body.encode(encoding), 'plain', _charset=encoding))
if html_body:
msg.attach(MIMEText(html_body.encode(encoding), "html", _charset=encoding))
msg.attach(MIMEText(html_body.encode(encoding), 'html', _charset=encoding))
if attachment_files:
for filepath in attachment_files:
with open(filepath, "rb") as fp:
part = MIMEBase("application", "octet-stream")
with open(filepath, 'rb') as fp:
part = MIMEBase('application', "octet-stream")
part.set_payload(fp.read())
encode_base64(part)
part.add_header(
"Content-Disposition",
f'attachment; filename="{os.path.basename(filepath)}"',
)
'Content-Disposition',
f'attachment; filename="{os.path.basename(filepath)}"')
msg.attach(part)
if attachment_payloads:
for filename, payload in attachment_payloads:
part = MIMEBase("application", "octet-stream")
part = MIMEBase('application', "octet-stream")
part.set_payload(payload)
encode_base64(part)
part.add_header("Content-Disposition", f'attachment; filename="{filename}"')
part.add_header(
'Content-Disposition',
f'attachment; filename="{filename}"')
msg.attach(part)
return msg
@ -245,184 +192,200 @@ class EmailClient(
:param msg: The message of this email (as MIMEBase or derivated classes)
:param subject: The subject of the email (only if the message is not provided
using msg parameter)
:param just_try: Enable just try mode (do not really send email, default: as defined on
initialization)
:param just_try: Enable just try mode (do not really send email, default: as defined on initialization)
All other parameters will be consider as parameters to forge the message
(only if the message is not provided using msg parameter).
"""
msg = msg if msg else self.forge_message(rcpt_to, subject, **forge_args)
if just_try or self._get_option("just_try"):
log.debug(
'Just-try mode: do not really send this email to %s (subject="%s")',
rcpt_to,
subject or msg.get("subject", "No subject"),
)
if just_try or self._get_option('just_try'):
log.debug('Just-try mode: do not really send this email to %s (subject="%s")', rcpt_to, subject or msg.get('subject', 'No subject'))
return True
catch_addr = self._get_option("catch_all_addr")
catch_addr = self._get_option('catch_all_addr')
if catch_addr:
log.debug("Catch email originaly send to %s to %s", rcpt_to, catch_addr)
log.debug('Catch email originaly send to %s to %s', rcpt_to, catch_addr)
rcpt_to = catch_addr
smtp_host = self._get_option("smtp_host")
smtp_port = self._get_option("smtp_port")
smtp_host = self._get_option('smtp_host')
smtp_port = self._get_option('smtp_port')
try:
if self._get_option("smtp_ssl"):
if self._get_option('smtp_ssl'):
logging.info("Establish SSL connection to server %s:%s", smtp_host, smtp_port)
server = smtplib.SMTP_SSL(smtp_host, smtp_port)
else:
logging.info("Establish connection to server %s:%s", smtp_host, smtp_port)
server = smtplib.SMTP(smtp_host, smtp_port)
if self._get_option("smtp_tls"):
logging.info("Start TLS on SMTP connection")
if self._get_option('smtp_tls'):
logging.info('Start TLS on SMTP connection')
server.starttls()
except smtplib.SMTPException:
log.error("Error connecting to SMTP server %s:%s", smtp_host, smtp_port, exc_info=True)
log.error('Error connecting to SMTP server %s:%s', smtp_host, smtp_port, exc_info=True)
return False
if self._get_option("smtp_debug"):
if self._get_option('smtp_debug'):
server.set_debuglevel(True)
smtp_user = self._get_option("smtp_user")
smtp_password = self._get_option("smtp_password")
smtp_user = self._get_option('smtp_user')
smtp_password = self._get_option('smtp_password')
if smtp_user and smtp_password:
try:
log.info("Try to authenticate on SMTP connection as %s", smtp_user)
log.info('Try to authenticate on SMTP connection as %s', smtp_user)
server.login(smtp_user, smtp_password)
except smtplib.SMTPException:
log.error(
"Error authenticating on SMTP server %s:%s with user %s",
smtp_host,
smtp_port,
smtp_user,
exc_info=True,
)
'Error authenticating on SMTP server %s:%s with user %s',
smtp_host, smtp_port, smtp_user, exc_info=True)
return False
error = False
try:
log.info("Sending email to %s", rcpt_to)
log.info('Sending email to %s', rcpt_to)
server.sendmail(
self._get_option("sender_email"),
self._get_option('sender_email'),
[rcpt_to[1] if isinstance(rcpt_to, tuple) else rcpt_to],
msg.as_string(),
msg.as_string()
)
except smtplib.SMTPException:
error = True
log.error("Error sending email to %s", rcpt_to, exc_info=True)
log.error('Error sending email to %s', rcpt_to, exc_info=True)
finally:
server.quit()
return not error
if __name__ == "__main__":
if __name__ == '__main__':
# Run tests
import argparse
import datetime
import sys
import argparse
# Options parser
parser = argparse.ArgumentParser()
parser.add_argument(
"-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode"
'-v', '--verbose',
action="store_true",
dest="verbose",
help="Enable verbose mode"
)
parser.add_argument(
"-d", "--debug", action="store_true", dest="debug", help="Enable debug mode"
'-d', '--debug',
action="store_true",
dest="debug",
help="Enable debug mode"
)
parser.add_argument(
"-l", "--log-file", action="store", type=str, dest="logfile", help="Log file path"
'-l', '--log-file',
action="store",
type=str,
dest="logfile",
help="Log file path"
)
parser.add_argument(
"-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode"
'-j', '--just-try',
action="store_true",
dest="just_try",
help="Enable just-try mode"
)
email_opts = parser.add_argument_group("Email options")
email_opts = parser.add_argument_group('Email options')
email_opts.add_argument(
"-H", "--smtp-host", action="store", type=str, dest="email_smtp_host", help="SMTP host"
'-H', '--smtp-host',
action="store",
type=str,
dest="email_smtp_host",
help="SMTP host"
)
email_opts.add_argument(
"-P", "--smtp-port", action="store", type=int, dest="email_smtp_port", help="SMTP port"
'-P', '--smtp-port',
action="store",
type=int,
dest="email_smtp_port",
help="SMTP port"
)
email_opts.add_argument(
"-S", "--smtp-ssl", action="store_true", dest="email_smtp_ssl", help="Use SSL"
'-S', '--smtp-ssl',
action="store_true",
dest="email_smtp_ssl",
help="Use SSL"
)
email_opts.add_argument(
"-T", "--smtp-tls", action="store_true", dest="email_smtp_tls", help="Use TLS"
'-T', '--smtp-tls',
action="store_true",
dest="email_smtp_tls",
help="Use TLS"
)
email_opts.add_argument(
"-u", "--smtp-user", action="store", type=str, dest="email_smtp_user", help="SMTP username"
'-u', '--smtp-user',
action="store",
type=str,
dest="email_smtp_user",
help="SMTP username"
)
email_opts.add_argument(
"-p",
"--smtp-password",
'-p', '--smtp-password',
action="store",
type=str,
dest="email_smtp_password",
help="SMTP password",
help="SMTP password"
)
email_opts.add_argument(
"-D",
"--smtp-debug",
'-D', '--smtp-debug',
action="store_true",
dest="email_smtp_debug",
help="Debug SMTP connection",
help="Debug SMTP connection"
)
email_opts.add_argument(
"-e",
"--email-encoding",
'-e', '--email-encoding',
action="store",
type=str,
dest="email_encoding",
help="SMTP encoding",
help="SMTP encoding"
)
email_opts.add_argument(
"-f",
"--sender-name",
'-f', '--sender-name',
action="store",
type=str,
dest="email_sender_name",
help="Sender name",
help="Sender name"
)
email_opts.add_argument(
"-F",
"--sender-email",
'-F', '--sender-email',
action="store",
type=str,
dest="email_sender_email",
help="Sender email",
help="Sender email"
)
email_opts.add_argument(
"-C",
"--catch-all",
'-C', '--catch-all',
action="store",
type=str,
dest="email_catch_all",
help="Catch all sent email: specify catch recipient email address",
help="Catch all sent email: specify catch recipient email address"
)
test_opts = parser.add_argument_group("Test email options")
test_opts = parser.add_argument_group('Test email options')
test_opts.add_argument(
"-t",
"--to",
'-t', '--to',
action="store",
type=str,
dest="test_to",
@ -430,8 +393,7 @@ if __name__ == "__main__":
)
test_opts.add_argument(
"-m",
"--mako",
'-m', '--mako',
action="store_true",
dest="test_mako",
help="Test mako templating",
@ -440,11 +402,11 @@ if __name__ == "__main__":
options = parser.parse_args()
if not options.test_to:
parser.error("You must specify test email recipient using -t/--to parameter")
parser.error('You must specify test email recipient using -t/--to parameter')
sys.exit(1)
# Initialize logs
logformat = "%(asctime)s - Test EmailClient - %(levelname)s - %(message)s"
logformat = '%(asctime)s - Test EmailClient - %(levelname)s - %(message)s'
if options.debug:
loglevel = logging.DEBUG
elif options.verbose:
@ -459,10 +421,9 @@ if __name__ == "__main__":
if options.email_smtp_user and not options.email_smtp_password:
import getpass
options.email_smtp_password = getpass.getpass('Please enter SMTP password: ')
options.email_smtp_password = getpass.getpass("Please enter SMTP password: ")
logging.info("Initialize Email client")
logging.info('Initialize Email client')
email_client = EmailClient(
smtp_host=options.email_smtp_host,
smtp_port=options.email_smtp_port,
@ -480,24 +441,20 @@ if __name__ == "__main__":
test=dict(
subject="Test email",
text=(
"Just a test email sent at {sent_date}."
if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
"Just a test email sent at {sent_date}." if not options.test_mako else
MakoTemplate("Just a test email sent at ${sent_date}.")
),
html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
if not options.test_mako
else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>")
)
)
),
)
),
)
logging.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
logging.info("Test email sent")
logging.info('Send a test email to %s', options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()):
logging.info('Test email sent')
sys.exit(0)
logging.error("Fail to send test email")
logging.error('Fail to send test email')
sys.exit(1)

View file

@ -1,13 +1,15 @@
# -*- coding: utf-8 -*-
""" LDAP server connection helper """
import copy
import datetime
import logging
import pytz
import dateutil.parser
import dateutil.tz
import ldap
import pytz
from ldap import modlist
from ldap.controls import SimplePagedResultsControl
from ldap.controls.simple import RelaxRulesControl
@ -16,28 +18,34 @@ from ldap.dn import escape_dn_chars, explode_dn
from mylib import pretty_format_dict
log = logging.getLogger(__name__)
DEFAULT_ENCODING = "utf-8"
DEFAULT_ENCODING = 'utf-8'
def decode_ldap_value(value, encoding="utf-8"):
def decode_ldap_value(value, encoding='utf-8'):
""" Decoding LDAP attribute values helper """
if isinstance(value, bytes):
return value.decode(encoding)
if isinstance(value, list):
return [decode_ldap_value(v) for v in value]
if isinstance(value, dict):
return {key: decode_ldap_value(values) for key, values in value.items()}
return dict(
(key, decode_ldap_value(values))
for key, values in value.items()
)
return value
def encode_ldap_value(value, encoding="utf-8"):
def encode_ldap_value(value, encoding='utf-8'):
""" Encoding LDAP attribute values helper """
if isinstance(value, str):
return value.encode(encoding)
if isinstance(value, list):
return [encode_ldap_value(v) for v in value]
if isinstance(value, dict):
return {key: encode_ldap_value(values) for key, values in value.items()}
return dict(
(key, encode_ldap_value(values))
for key, values in value.items()
)
return value
@ -51,17 +59,9 @@ class LdapServer:
con = 0
def __init__(
self,
uri,
dn=None,
pwd=None,
v2=None,
raiseOnError=False,
logger=False,
checkCert=True,
disableReferral=False,
):
def __init__(self, uri, dn=None, pwd=None, v2=None,
raiseOnError=False, logger=False, checkCert=True,
disableReferral=False):
self.uri = uri
self.dn = dn
self.pwd = pwd
@ -98,27 +98,26 @@ class LdapServer:
if self.dn:
con.simple_bind_s(self.dn, self.pwd)
elif self.uri.startswith("ldapi://"):
elif self.uri.startswith('ldapi://'):
con.sasl_interactive_bind_s("", ldap.sasl.external())
self.con = con
return True
except ldap.LDAPError as e: # pylint: disable=no-member
self._error(
f"LdapServer - Error connecting and binding to LDAP server: {e}",
logging.CRITICAL,
)
f'LdapServer - Error connecting and binding to LDAP server: {e}',
logging.CRITICAL)
return False
return True
@staticmethod
def get_scope(scope):
""" Map scope parameter to python-ldap value """
if scope == "base":
if scope == 'base':
return ldap.SCOPE_BASE # pylint: disable=no-member
if scope == "one":
if scope == 'one':
return ldap.SCOPE_ONELEVEL # pylint: disable=no-member
if scope == "sub":
if scope == 'sub':
return ldap.SCOPE_SUBTREE # pylint: disable=no-member
raise Exception(f'Unknown LDAP scope "{scope}"')
@ -127,9 +126,9 @@ class LdapServer:
assert self.con or self.connect()
res_id = self.con.search(
basedn,
self.get_scope(scope if scope else "sub"),
filterstr if filterstr else "(objectClass=*)",
attrs if attrs else [],
self.get_scope(scope if scope else 'sub'),
filterstr if filterstr else '(objectClass=*)',
attrs if attrs else []
)
ret = {}
c = 0
@ -144,63 +143,64 @@ class LdapServer:
def get_object(self, dn, filterstr=None, attrs=None):
""" Retrieve a LDAP object specified by its DN """
result = self.search(dn, filterstr=filterstr, scope="base", attrs=attrs)
result = self.search(dn, filterstr=filterstr, scope='base', attrs=attrs)
return result[dn] if dn in result else None
def paged_search(
self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None, sizelimit=None
):
def paged_search(self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None,
sizelimit=None):
""" Run a paged search on LDAP server """
assert not self.v2, "Paged search is not available on LDAP version 2"
assert self.con or self.connect()
# Set parameters default values (if not defined)
filterstr = filterstr if filterstr else "(objectClass=*)"
filterstr = filterstr if filterstr else '(objectClass=*)'
attrs = attrs if attrs else []
scope = scope if scope else "sub"
scope = scope if scope else 'sub'
pagesize = pagesize if pagesize else 500
# Initialize SimplePagedResultsControl object
page_control = SimplePagedResultsControl(
True, size=pagesize, cookie="" # Start without cookie
True,
size=pagesize,
cookie='' # Start without cookie
)
ret = {}
pages_count = 0
self.logger.debug(
"LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d"
" and attrs=%s",
"LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d and attrs=%s",
basedn,
filterstr,
scope,
pagesize,
attrs,
attrs
)
while True:
pages_count += 1
self.logger.debug(
"LdapServer - Paged search: request page %d with a maximum of %d objects"
" (current total count: %d)",
"LdapServer - Paged search: request page %d with a maximum of %d objects (current total count: %d)",
pages_count,
pagesize,
len(ret),
len(ret)
)
try:
res_id = self.con.search_ext(
basedn, self.get_scope(scope), filterstr, attrs, serverctrls=[page_control]
basedn,
self.get_scope(scope),
filterstr,
attrs,
serverctrls=[page_control]
)
except ldap.LDAPError as e: # pylint: disable=no-member
self._error(
f"LdapServer - Error running paged search on LDAP server: {e}", logging.CRITICAL
)
f'LdapServer - Error running paged search on LDAP server: {e}',
logging.CRITICAL)
return False
try:
# pylint: disable=unused-variable
rtype, rdata, rmsgid, rctrls = self.con.result3(res_id)
rtype, rdata, rmsgid, rctrls = self.con.result3(res_id) # pylint: disable=unused-variable
except ldap.LDAPError as e: # pylint: disable=no-member
self._error(
f"LdapServer - Error pulling paged search result from LDAP server: {e}",
logging.CRITICAL,
)
f'LdapServer - Error pulling paged search result from LDAP server: {e}',
logging.CRITICAL)
return False
# Detect and catch PagedResultsControl answer from rctrls
@ -214,9 +214,8 @@ class LdapServer:
# If PagedResultsControl answer not detected, paged serach
if not result_page_control:
self._error(
"LdapServer - Server ignores RFC2696 control, paged search can not works",
logging.CRITICAL,
)
'LdapServer - Server ignores RFC2696 control, paged search can not works',
logging.CRITICAL)
return False
# Store results of this page
@ -237,12 +236,7 @@ class LdapServer:
# Otherwise, set cookie for the next search
page_control.cookie = result_page_control.cookie
self.logger.debug(
"LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)",
len(ret),
pages_count,
pagesize,
)
self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize)
return ret
def add_object(self, dn, attrs, encode=False):
@ -254,7 +248,7 @@ class LdapServer:
self.con.add_s(dn, ldif)
return True
except ldap.LDAPError as e: # pylint: disable=no-member
self._error(f"LdapServer - Error adding {dn}: {e}", logging.ERROR)
self._error(f'LdapServer - Error adding {dn}: {e}', logging.ERROR)
return False
@ -264,7 +258,7 @@ class LdapServer:
ldif = modlist.modifyModlist(
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [],
ignore_attr_types=ignore_attrs if ignore_attrs else []
)
if not ldif:
return True
@ -277,8 +271,8 @@ class LdapServer:
return True
except ldap.LDAPError as e: # pylint: disable=no-member
self._error(
f"LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}", logging.ERROR
)
f'LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}',
logging.ERROR)
return False
@staticmethod
@ -287,7 +281,7 @@ class LdapServer:
ldif = modlist.modifyModlist(
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [],
ignore_attr_types=ignore_attrs if ignore_attrs else []
)
if not ldif:
return False
@ -299,50 +293,44 @@ class LdapServer:
return modlist.modifyModlist(
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [],
ignore_attr_types=ignore_attrs if ignore_attrs else []
)
@staticmethod
def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False):
"""
Format changes (modlist) on an object based on its old and new attributes values to
display/log it
"""
""" Format changes (modlist) on an object based on its old and new attributes values to display/log it """
msg = []
prefix = prefix if prefix else ""
for op, attr, val in modlist.modifyModlist(
prefix = prefix if prefix else ''
for (op, attr, val) in modlist.modifyModlist(
encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [],
ignore_attr_types=ignore_attrs if ignore_attrs else []
):
if op == ldap.MOD_ADD: # pylint: disable=no-member
op = "ADD"
op = 'ADD'
elif op == ldap.MOD_DELETE: # pylint: disable=no-member
op = "DELETE"
op = 'DELETE'
elif op == ldap.MOD_REPLACE: # pylint: disable=no-member
op = "REPLACE"
op = 'REPLACE'
else:
op = f"UNKNOWN (={op})"
if val is None and op == "DELETE":
msg.append(f"{prefix} - {op} {attr}")
op = f'UNKNOWN (={op})'
if val is None and op == 'DELETE':
msg.append(f'{prefix} - {op} {attr}')
else:
msg.append(f"{prefix} - {op} {attr}: {val}")
return "\n".join(msg)
msg.append(f'{prefix} - {op} {attr}: {val}')
return '\n'.join(msg)
def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True):
""" Rename an object in LDAP directory """
# If new_rdn is a complete DN, split new RDN and new superior DN
if len(explode_dn(new_rdn)) > 1:
self.logger.debug(
"LdapServer - Rename with a full new DN detected (%s): split new RDN and new"
" superior DN",
new_rdn,
"LdapServer - Rename with a full new DN detected (%s): split new RDN and new superior DN",
new_rdn
)
assert (
new_sup is None
), "You can't provide a complete DN as new_rdn and also provide new_sup parameter"
assert new_sup is None, "You can't provide a complete DN as new_rdn and also provide new_sup parameter"
new_dn_parts = explode_dn(new_rdn)
new_sup = ",".join(new_dn_parts[1:])
new_sup = ','.join(new_dn_parts[1:])
new_rdn = new_dn_parts[0]
assert self.con or self.connect()
try:
@ -351,16 +339,16 @@ class LdapServer:
dn,
new_rdn,
"same" if new_sup is None else new_sup,
delete_old,
delete_old
)
self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old)
return True
except ldap.LDAPError as e: # pylint: disable=no-member
self._error(
f"LdapServer - Error renaming {dn} in {new_rdn} "
f'LdapServer - Error renaming {dn} in {new_rdn} '
f'(new superior: {"same" if new_sup is None else new_sup}, '
f"delete old: {delete_old}): {e}",
logging.ERROR,
f'delete old: {delete_old}): {e}',
logging.ERROR
)
return False
@ -373,7 +361,8 @@ class LdapServer:
self.con.delete_s(dn)
return True
except ldap.LDAPError as e: # pylint: disable=no-member
self._error(f"LdapServer - Error deleting {dn}: {e}", logging.ERROR)
self._error(
f'LdapServer - Error deleting {dn}: {e}', logging.ERROR)
return False
@ -427,13 +416,11 @@ class LdapClient:
# Cache objects
_cached_objects = None
def __init__(
self, options=None, options_prefix=None, config=None, config_section=None, initialize=False
):
def __init__(self, options=None, options_prefix=None, config=None, config_section=None, initialize=False):
self._options = options if options else {}
self._options_prefix = options_prefix if options_prefix else "ldap_"
self._options_prefix = options_prefix if options_prefix else 'ldap_'
self._config = config if config else None
self._config_section = config_section if config_section else "ldap"
self._config_section = config_section if config_section else 'ldap'
self._cached_objects = {}
if initialize:
self.initialize()
@ -446,7 +433,7 @@ class LdapClient:
if self._config and self._config.defined(self._config_section, option):
return self._config.get(self._config_section, option)
assert not required, f"Options {option} not defined"
assert not required, f'Options {option} not defined'
return default
@ -454,45 +441,41 @@ class LdapClient:
def _just_try(self):
""" Check if just-try mode is enabled """
return self._get_option(
"just_try", default=(self._config.get_option("just_try") if self._config else False)
'just_try', default=(
self._config.get_option('just_try') if self._config
else False
)
)
def configure(self, comment=None, **kwargs):
""" Configure options on registered mylib.Config object """
assert self._config, (
"mylib.Config object not registered. Must be passed to __init__ as config keyword"
" argument."
)
assert self._config, "mylib.Config object not registered. Must be passed to __init__ as config keyword argument."
# Load configuration option types only here to avoid global
# dependency of ldap module with config one.
# pylint: disable=import-outside-toplevel
from mylib.config import BooleanOption, PasswordOption, StringOption
from mylib.config import BooleanOption, StringOption, PasswordOption
section = self._config.add_section(
self._config_section,
comment=comment if comment else "LDAP connection",
loaded_callback=self.initialize,
**kwargs,
)
comment=comment if comment else 'LDAP connection',
loaded_callback=self.initialize, **kwargs)
section.add_option(
StringOption, "uri", default="ldap://localhost", comment="LDAP server URI"
)
section.add_option(StringOption, "binddn", comment="LDAP Bind DN")
StringOption, 'uri', default='ldap://localhost',
comment='LDAP server URI')
section.add_option(
PasswordOption,
"bindpwd",
StringOption, 'binddn', comment='LDAP Bind DN')
section.add_option(
PasswordOption, 'bindpwd',
comment='LDAP Bind password (set to "keyring" to use XDG keyring)',
username_option="binddn",
keyring_value="keyring",
)
username_option='binddn', keyring_value='keyring')
section.add_option(
BooleanOption, "checkcert", default=True, comment="Check LDAP certificate"
)
BooleanOption, 'checkcert', default=True,
comment='Check LDAP certificate')
section.add_option(
BooleanOption, "disablereferral", default=False, comment="Disable referral following"
)
BooleanOption, 'disablereferral', default=False,
comment='Disable referral following')
return section
@ -500,16 +483,14 @@ class LdapClient:
""" Initialize LDAP connection """
if loaded_config:
self.config = loaded_config
uri = self._get_option("uri", required=True)
binddn = self._get_option("binddn")
log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else "annonymous")
uri = self._get_option('uri', required=True)
binddn = self._get_option('binddn')
log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else 'annonymous')
self._conn = LdapServer(
uri,
dn=binddn,
pwd=self._get_option("bindpwd"),
checkCert=self._get_option("checkcert"),
disableReferral=self._get_option("disablereferral"),
raiseOnError=True,
uri, dn=binddn, pwd=self._get_option('bindpwd'),
checkCert=self._get_option('checkcert'),
disableReferral=self._get_option('disablereferral'),
raiseOnError=True
)
# Reset cache
self._cached_objects = {}
@ -522,8 +503,8 @@ class LdapClient:
if isinstance(value, str):
return value
return value.decode(
self._get_option("encoding", default=DEFAULT_ENCODING),
self._get_option("encoding_error_policy", default="ignore"),
self._get_option('encoding', default=DEFAULT_ENCODING),
self._get_option('encoding_error_policy', default='ignore')
)
def encode(self, value):
@ -532,7 +513,7 @@ class LdapClient:
return [self.encode(v) for v in value]
if isinstance(value, bytes):
return value
return value.encode(self._get_option("encoding", default=DEFAULT_ENCODING))
return value.encode(self._get_option('encoding', default=DEFAULT_ENCODING))
def _get_obj(self, dn, attrs):
"""
@ -567,17 +548,8 @@ class LdapClient:
return vals if all_values else vals[0]
return default if default or not all_values else []
def get_objects(
self,
name,
filterstr,
basedn,
attrs,
key_attr=None,
warn=True,
paged_search=False,
pagesize=None,
):
def get_objects(self, name, filterstr, basedn, attrs, key_attr=None, warn=True,
paged_search=False, pagesize=None):
"""
Retrieve objects from LDAP
@ -596,28 +568,25 @@ class LdapClient:
(optional, default: see LdapServer.paged_search)
"""
if name in self._cached_objects:
log.debug("Retreived %s objects from cache", name)
log.debug('Retreived %s objects from cache', name)
else:
assert self._conn or self.initialize()
log.debug(
'Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn
)
log.debug('Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn)
if paged_search:
ldap_data = self._conn.paged_search(
basedn=basedn, filterstr=filterstr, attrs=attrs, pagesize=pagesize
basedn=basedn, filterstr=filterstr, attrs=attrs,
pagesize=pagesize
)
else:
ldap_data = self._conn.search(
basedn=basedn,
filterstr=filterstr,
attrs=attrs,
basedn=basedn, filterstr=filterstr, attrs=attrs,
)
if not ldap_data:
if warn:
log.warning("No %s found in LDAP", name)
log.warning('No %s found in LDAP', name)
else:
log.debug("No %s found in LDAP", name)
log.debug('No %s found in LDAP', name)
return {}
objects = {}
@ -627,12 +596,12 @@ class LdapClient:
continue
objects[obj_dn] = self._get_obj(obj_dn, obj_attrs)
self._cached_objects[name] = objects
if not key_attr or key_attr == "dn":
if not key_attr or key_attr == 'dn':
return self._cached_objects[name]
return {
self.get_attr(self._cached_objects[name][dn], key_attr): self._cached_objects[name][dn]
return dict(
(self.get_attr(self._cached_objects[name][dn], key_attr), self._cached_objects[name][dn])
for dn in self._cached_objects[name]
}
)
def get_object(self, type_name, object_name, filterstr, basedn, attrs, warn=True):
"""
@ -651,14 +620,11 @@ class LdapClient:
(optional, default: True)
"""
assert self._conn or self.initialize()
log.debug(
'Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")',
type_name,
object_name,
filterstr,
basedn,
log.debug('Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")', type_name, object_name, filterstr, basedn)
ldap_data = self._conn.search(
basedn=basedn, filterstr=filterstr,
attrs=attrs
)
ldap_data = self._conn.search(basedn=basedn, filterstr=filterstr, attrs=attrs)
if not ldap_data:
if warn:
@ -669,8 +635,7 @@ class LdapClient:
if len(ldap_data) > 1:
raise LdapClientException(
f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}'
)
f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}')
dn = next(iter(ldap_data))
return self._get_obj(dn, ldap_data[dn])
@ -694,9 +659,9 @@ class LdapClient:
populate_cache_method()
if type_name not in self._cached_objects:
if warn:
log.warning("No %s found in LDAP", type_name)
log.warning('No %s found in LDAP', type_name)
else:
log.debug("No %s found in LDAP", type_name)
log.debug('No %s found in LDAP', type_name)
return None
if dn not in self._cached_objects[type_name]:
if warn:
@ -721,9 +686,7 @@ class LdapClient:
return value in cls.get_attr(obj, attr, all_values=True)
return value.lower() in [v.lower() for v in cls.get_attr(obj, attr, all_values=True)]
def get_object_by_attr(
self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True
):
def get_object_by_attr(self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True):
"""
Retrieve an LDAP object specified by one of its attribute
@ -745,15 +708,15 @@ class LdapClient:
populate_cache_method()
if type_name not in self._cached_objects:
if warn:
log.warning("No %s found in LDAP", type_name)
log.warning('No %s found in LDAP', type_name)
else:
log.debug("No %s found in LDAP", type_name)
log.debug('No %s found in LDAP', type_name)
return None
matched = {
dn: obj
matched = dict(
(dn, obj)
for dn, obj in self._cached_objects[type_name].items()
if self.object_attr_mached(obj, attr, value, case_sensitive=case_sensitive)
}
)
if not matched:
if warn:
log.warning('No %s found with %s="%s"', type_name, attr, value)
@ -763,8 +726,7 @@ class LdapClient:
if len(matched) > 1:
raise LdapClientException(
f'More than one {type_name} with {attr}="{value}" found: '
f'{" / ".join(matched.keys())}'
)
f'{" / ".join(matched.keys())}')
dn = next(iter(matched))
return matched[dn]
@ -780,7 +742,7 @@ class LdapClient:
old = {}
new = {}
protected_attrs = [a.lower() for a in protected_attrs or []]
protected_attrs.append("dn")
protected_attrs.append('dn')
# New/updated attributes
for attr in attrs:
if protected_attrs and attr.lower() in protected_attrs:
@ -793,11 +755,7 @@ class LdapClient:
# Deleted attributes
for attr in ldap_obj:
if (
(not protected_attrs or attr.lower() not in protected_attrs)
and ldap_obj[attr]
and attr not in attrs
):
if (not protected_attrs or attr.lower() not in protected_attrs) and ldap_obj[attr] and attr not in attrs:
old[attr] = self.encode(ldap_obj[attr])
if old == new:
return None
@ -813,7 +771,8 @@ class LdapClient:
"""
assert self._conn or self.initialize()
return self._conn.format_changes(
changes[0], changes[1], ignore_attrs=protected_attrs, prefix=prefix
changes[0], changes[1],
ignore_attrs=protected_attrs, prefix=prefix
)
def update_need(self, changes, protected_attrs=None):
@ -825,7 +784,10 @@ class LdapClient:
if changes is None:
return False
assert self._conn or self.initialize()
return self._conn.update_need(changes[0], changes[1], ignore_attrs=protected_attrs)
return self._conn.update_need(
changes[0], changes[1],
ignore_attrs=protected_attrs
)
def add_object(self, dn, attrs):
"""
@ -834,19 +796,21 @@ class LdapClient:
:param dn: The LDAP object DN
:param attrs: The LDAP object attributes (as dict)
"""
attrs = {attr: self.encode(values) for attr, values in attrs.items() if attr != "dn"}
attrs = dict(
(attr, self.encode(values))
for attr, values in attrs.items()
if attr != 'dn'
)
try:
if self._just_try:
log.debug("Just-try mode : do not really add object in LDAP")
log.debug('Just-try mode : do not really add object in LDAP')
return True
assert self._conn or self.initialize()
return self._conn.add_object(dn, attrs)
except LdapServerException:
log.error(
"An error occurred adding object %s in LDAP:\n%s\n",
dn,
pretty_format_dict(attrs),
exc_info=True,
dn, pretty_format_dict(attrs), exc_info=True
)
return False
@ -859,42 +823,35 @@ class LdapClient:
:param protected_attrs: An optional list of protected attributes
:param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected)
"""
assert (
isinstance(changes, (list, tuple))
and len(changes) == 2
and isinstance(changes[0], dict)
and isinstance(changes[1], dict)
), f"changes parameter must be a result of get_changes() method ({type(changes)} given)"
assert isinstance(changes, (list, tuple)) and len(changes) == 2 and isinstance(changes[0], dict) and isinstance(changes[1], dict), f'changes parameter must be a result of get_changes() method ({type(changes)} given)'
# In case of RDN change, we need to modify passed changes, copy it to make it unchanged in
# this case
_changes = copy.deepcopy(changes)
if not rdn_attr:
rdn_attr = ldap_obj["dn"].split("=")[0]
log.debug("Auto-detected RDN attribute from DN: %s => %s", ldap_obj["dn"], rdn_attr)
rdn_attr = ldap_obj['dn'].split('=')[0]
log.debug('Auto-detected RDN attribute from DN: %s => %s', ldap_obj['dn'], rdn_attr)
old_rdn_values = self.get_attr(_changes[0], rdn_attr, all_values=True)
new_rdn_values = self.get_attr(_changes[1], rdn_attr, all_values=True)
if old_rdn_values or new_rdn_values:
if not new_rdn_values:
log.error(
"%s : Attribute %s can't be deleted because it's used as RDN.",
ldap_obj["dn"],
rdn_attr,
ldap_obj['dn'], rdn_attr
)
return False
log.debug(
"%s: Changes detected on %s RDN attribute: must rename object before updating it",
ldap_obj["dn"],
rdn_attr,
'%s: Changes detected on %s RDN attribute: must rename object before updating it',
ldap_obj['dn'], rdn_attr
)
# Compute new object DN
dn_parts = explode_dn(self.decode(ldap_obj["dn"]))
basedn = ",".join(dn_parts[1:])
new_rdn = f"{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}"
new_dn = f"{new_rdn},{basedn}"
dn_parts = explode_dn(self.decode(ldap_obj['dn']))
basedn = ','.join(dn_parts[1:])
new_rdn = f'{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}'
new_dn = f'{new_rdn},{basedn}'
# Rename object
log.debug("%s: Rename to %s", ldap_obj["dn"], new_dn)
log.debug('%s: Rename to %s', ldap_obj['dn'], new_dn)
if not self.move_object(ldap_obj, new_dn):
return False
@ -908,29 +865,30 @@ class LdapClient:
# Check that there are other changes
if not _changes[0] and not _changes[1]:
log.debug("%s: No other change after renaming", new_dn)
log.debug('%s: No other change after renaming', new_dn)
return True
# Otherwise, update object DN
ldap_obj["dn"] = new_dn
ldap_obj['dn'] = new_dn
else:
log.debug("%s: No change detected on RDN attibute %s", ldap_obj["dn"], rdn_attr)
log.debug('%s: No change detected on RDN attibute %s', ldap_obj['dn'], rdn_attr)
try:
if self._just_try:
log.debug("Just-try mode : do not really update object in LDAP")
log.debug('Just-try mode : do not really update object in LDAP')
return True
assert self._conn or self.initialize()
return self._conn.update_object(
ldap_obj["dn"], _changes[0], _changes[1], ignore_attrs=protected_attrs
ldap_obj['dn'],
_changes[0],
_changes[1],
ignore_attrs=protected_attrs
)
except LdapServerException:
log.error(
"An error occurred updating object %s in LDAP:\n%s\n -> \n%s\n\n",
ldap_obj["dn"],
pretty_format_dict(_changes[0]),
pretty_format_dict(_changes[1]),
exc_info=True,
ldap_obj['dn'], pretty_format_dict(_changes[0]), pretty_format_dict(_changes[1]),
exc_info=True
)
return False
@ -943,16 +901,14 @@ class LdapClient:
"""
try:
if self._just_try:
log.debug("Just-try mode : do not really move object in LDAP")
log.debug('Just-try mode : do not really move object in LDAP')
return True
assert self._conn or self.initialize()
return self._conn.rename_object(ldap_obj["dn"], new_dn_or_rdn)
return self._conn.rename_object(ldap_obj['dn'], new_dn_or_rdn)
except LdapServerException:
log.error(
"An error occurred moving object %s in LDAP (destination: %s)",
ldap_obj["dn"],
new_dn_or_rdn,
exc_info=True,
ldap_obj['dn'], new_dn_or_rdn, exc_info=True
)
return False
@ -964,12 +920,15 @@ class LdapClient:
"""
try:
if self._just_try:
log.debug("Just-try mode : do not really drop object in LDAP")
log.debug('Just-try mode : do not really drop object in LDAP')
return True
assert self._conn or self.initialize()
return self._conn.drop_object(ldap_obj["dn"])
return self._conn.drop_object(ldap_obj['dn'])
except LdapServerException:
log.error("An error occurred removing object %s in LDAP", ldap_obj["dn"], exc_info=True)
log.error(
"An error occurred removing object %s in LDAP",
ldap_obj['dn'], exc_info=True
)
return False
@ -984,29 +943,20 @@ def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
:param value: The LDAP date string to convert
:param to_timezone: If specified, the return datetime will be converted to this
specific timezone (optional, default : timezone of the LDAP date
string)
specific timezone (optional, default : timezone of the LDAP date string)
:param default_timezone: The timezone used if LDAP date string does not specified
the timezone (optional, default : server local timezone)
:param naive: Use naive datetime : return naive datetime object (without timezone
conversion from LDAP)
:param naive: Use naive datetime : return naive datetime object (without timezone conversion from LDAP)
"""
assert to_timezone is None or isinstance(
to_timezone, (datetime.tzinfo, str)
), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"
assert default_timezone is None or isinstance(
default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
), (
"default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
f" datetime.tzinfo object (not {type(default_timezone)})"
)
assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})'
assert default_timezone is None or isinstance(default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(default_timezone)})'
date = dateutil.parser.parse(value, dayfirst=False)
if not date.tzinfo:
if naive:
return date
if not default_timezone:
default_timezone = pytz.utc
elif default_timezone == "local":
elif default_timezone == 'local':
default_timezone = dateutil.tz.tzlocal()
elif isinstance(default_timezone, str):
default_timezone = pytz.timezone(default_timezone)
@ -1019,7 +969,7 @@ def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
elif naive:
return date.replace(tzinfo=None)
if to_timezone:
if to_timezone == "local":
if to_timezone == 'local':
to_timezone = dateutil.tz.tzlocal()
elif isinstance(to_timezone, str):
to_timezone = pytz.timezone(to_timezone)
@ -1033,8 +983,7 @@ def parse_date(value, to_timezone=None, default_timezone=None, naive=True):
:param value: The LDAP date string to convert
:param to_timezone: If specified, the return datetime will be converted to this
specific timezone (optional, default : timezone of the LDAP date
string)
specific timezone (optional, default : timezone of the LDAP date string)
:param default_timezone: The timezone used if LDAP date string does not specified
the timezone (optional, default : server local timezone)
:param naive: Use naive datetime : do not handle timezone conversion from LDAP
@ -1050,23 +999,13 @@ def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
:param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo)
(optional, default : server local timezone)
:param to_timezone: The timezone used in LDAP (optional, default : UTC)
:param naive: Use naive datetime : datetime store as UTC in LDAP (without
conversion)
:param naive: Use naive datetime : datetime store as UTC in LDAP (without conversion)
"""
assert isinstance(
value, datetime.datetime
), f"First parameter must be an datetime.datetime object (not {type(value)})"
assert from_timezone is None or isinstance(
from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
), (
"from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
f" datetime.tzinfo object (not {type(from_timezone)})"
)
assert to_timezone is None or isinstance(
to_timezone, (datetime.tzinfo, str)
), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"
assert isinstance(value, datetime.datetime), f'First parameter must be an datetime.datetime object (not {type(value)})'
assert from_timezone is None or isinstance(from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(from_timezone)})'
assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})'
if not value.tzinfo and not naive:
if not from_timezone or from_timezone == "local":
if not from_timezone or from_timezone == 'local':
from_timezone = dateutil.tz.tzlocal()
elif isinstance(from_timezone, str):
from_timezone = pytz.timezone(from_timezone)
@ -1082,14 +1021,14 @@ def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
from_value = copy.deepcopy(value)
if not to_timezone:
to_timezone = pytz.utc
elif to_timezone == "local":
elif to_timezone == 'local':
to_timezone = dateutil.tz.tzlocal()
elif isinstance(to_timezone, str):
to_timezone = pytz.timezone(to_timezone)
to_value = from_value.astimezone(to_timezone) if not naive else from_value
datestring = to_value.strftime("%Y%m%d%H%M%S%z")
if datestring.endswith("+0000"):
datestring = datestring.replace("+0000", "Z")
datestring = to_value.strftime('%Y%m%d%H%M%S%z')
if datestring.endswith('+0000'):
datestring = datestring.replace('+0000', 'Z')
return datestring
@ -1101,16 +1040,8 @@ def format_date(value, from_timezone=None, to_timezone=None, naive=True):
:param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo)
(optional, default : server local timezone)
:param to_timezone: The timezone used in LDAP (optional, default : UTC)
:param naive: Use naive datetime : do not handle timezone conversion before
formating and return datetime as UTC (because LDAP required a
timezone)
:param naive: Use naive datetime : do not handle timezone conversion before formating
and return datetime as UTC (because LDAP required a timezone)
"""
assert isinstance(
value, datetime.date
), f"First parameter must be an datetime.date object (not {type(value)})"
return format_datetime(
datetime.datetime.combine(value, datetime.datetime.min.time()),
from_timezone,
to_timezone,
naive,
)
assert isinstance(value, datetime.date), f'First parameter must be an datetime.date object (not {type(value)})'
return format_datetime(datetime.datetime.combine(value, datetime.datetime.min.time()), from_timezone, to_timezone, naive)

View file

@ -1,138 +0,0 @@
"""
My hash mapping library
Mapping configuration
{
'[dst key 1]': { # Key name in the result
'order': [int], # Processing order between destinations keys
# Source values
'other_key': [key], # Other key of the destination to use as source of values
'key' : '[src key]', # Key of source hash to get source values
'keys' : ['[sk1]', '[sk2]', ...], # List of source hash's keys to get source values
# Clean / convert values
'cleanRegex': '[regex]', # Regex that be use to remove unwanted characters. Ex : [^0-9+]
'convert': [function], # Function to use to convert value : Original value will be passed
# as argument and the value retrieve will replace source value in
# the result
# Ex :
# lambda x: x.strip()
# lambda x: "myformat : %s" % x
# Deduplicate / check values
'deduplicate': [bool], # If True, sources values will be depluplicated
'check': [function], # Function to use to check source value : Source value will be passed
# as argument and if function return True, the value will be preserved
# Ex :
# lambda x: x in my_global_hash
# Join values
'join': '[glue]', # If present, sources values will be join using the "glue"
# Alternative mapping
'or': { [map configuration] } # If this mapping case does not retreive any value, try to
# get value(s) with this other mapping configuration
},
'[dst key 2]': {
[...]
}
}
Return format :
{
'[dst key 1]': ['v1','v2', ...],
'[dst key 2]': [ ... ],
[...]
}
"""
import logging
import re
log = logging.getLogger(__name__)
def clean_value(value):
"""Clean value as encoded string"""
if isinstance(value, int):
value = str(value)
return value
def get_values(dst, dst_key, src, m):
"""Extract sources values"""
values = []
if "other_key" in m:
if m["other_key"] in dst:
values = dst[m["other_key"]]
if "key" in m:
if m["key"] in src and src[m["key"]] != "":
values.append(clean_value(src[m["key"]]))
if "keys" in m:
for key in m["keys"]:
if key in src and src[key] != "":
values.append(clean_value(src[key]))
# Clean and convert values
if "cleanRegex" in m and len(values) > 0:
new_values = []
for v in values:
nv = re.sub(m["cleanRegex"], "", v)
if nv != "":
new_values.append(nv)
values = new_values
if "convert" in m and len(values) > 0:
new_values = []
for v in values:
nv = m["convert"](v)
if nv != "":
new_values.append(nv)
values = new_values
# Deduplicate values
if m.get("deduplicate") and len(values) > 1:
new_values = []
for v in values:
if v not in new_values:
new_values.append(v)
values = new_values
# Check values
if "check" in m and len(values) > 0:
new_values = []
for v in values:
if m["check"](v):
new_values.append(v)
else:
log.debug("Invalid value %s for key %s", v, dst_key)
values = new_values
# Join values
if "join" in m and len(values) > 1:
values = [m["join"].join(values)]
# Manage alternative mapping case
if len(values) == 0 and "or" in m:
values = get_values(dst, dst_key, src, m["or"])
return values
def map_hash(mapping, src, dst=None):
"""Map hash"""
dst = dst if dst else {}
assert isinstance(dst, dict)
for dst_key in sorted(mapping.keys(), key=lambda x: mapping[x]["order"]):
values = get_values(dst, dst_key, src, mapping[dst_key])
if len(values) == 0:
if "required" in mapping[dst_key] and mapping[dst_key]["required"]:
log.debug(
"Destination key %s could not be filled from source but is required", dst_key
)
return False
continue
dst[dst_key] = values
return dst

View file

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
""" MySQL client """
import logging
@ -6,7 +8,9 @@ import sys
import MySQLdb
from MySQLdb._exceptions import Error
from mylib.db import DB, DBFailToConnect
from mylib.db import DB
from mylib.db import DBFailToConnect
log = logging.getLogger(__name__)
@ -24,7 +28,7 @@ class MyDB(DB):
self._user = user
self._pwd = pwd
self._db = db
self._charset = charset if charset else "utf8"
self._charset = charset if charset else 'utf8'
super().__init__(**kwargs)
def connect(self, exit_on_error=True):
@ -32,25 +36,17 @@ class MyDB(DB):
if self._conn is None:
try:
self._conn = MySQLdb.connect(
host=self._host,
user=self._user,
passwd=self._pwd,
db=self._db,
charset=self._charset,
use_unicode=True,
)
host=self._host, user=self._user, passwd=self._pwd,
db=self._db, charset=self._charset, use_unicode=True)
except Error as err:
log.fatal(
"An error occured during MySQL database connection (%s@%s:%s).",
self._user,
self._host,
self._db,
exc_info=1,
'An error occured during MySQL database connection (%s@%s:%s).',
self._user, self._host, self._db, exc_info=1
)
if exit_on_error:
sys.exit(1)
else:
raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err
raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err
return True
def doSQL(self, sql, params=None):
@ -92,7 +88,10 @@ class MyDB(DB):
cursor = self._conn.cursor()
cursor.execute(sql, params)
return [
{field[0]: row[idx] for idx, field in enumerate(cursor.description)}
dict(
(field[0], row[idx])
for idx, field in enumerate(cursor.description)
)
for row in cursor.fetchall()
]
except Error:
@ -102,11 +101,13 @@ class MyDB(DB):
@staticmethod
def _quote_table_name(table):
""" Quote table name """
return "`{}`".format( # pylint: disable=consider-using-f-string
"`.`".join(table.split("."))
return '`{0}`'.format( # pylint: disable=consider-using-f-string
'`.`'.join(
table.split('.')
)
)
@staticmethod
def _quote_field_name(field):
""" Quote table name """
return f"`{field}`"
return f'`{field}`'

View file

@ -1,16 +1,18 @@
# -*- coding: utf-8 -*-
""" Opening hours helpers """
import datetime
import logging
import re
import time
import logging
log = logging.getLogger(__name__)
week_days = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"]
date_format = "%d/%m/%Y"
date_pattern = re.compile("^([0-9]{2})/([0-9]{2})/([0-9]{4})$")
time_pattern = re.compile("^([0-9]{1,2})h([0-9]{2})?$")
week_days = ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi', 'samedi', 'dimanche']
date_format = '%d/%m/%Y'
date_pattern = re.compile('^([0-9]{2})/([0-9]{2})/([0-9]{4})$')
time_pattern = re.compile('^([0-9]{1,2})h([0-9]{2})?$')
def easter_date(year):
@ -39,20 +41,20 @@ def nonworking_french_public_days_of_the_year(year=None):
year = datetime.date.today().year
dp = easter_date(year)
return {
"1janvier": datetime.date(year, 1, 1),
"paques": dp,
"lundi_paques": (dp + datetime.timedelta(1)),
"1mai": datetime.date(year, 5, 1),
"8mai": datetime.date(year, 5, 8),
"jeudi_ascension": (dp + datetime.timedelta(39)),
"pentecote": (dp + datetime.timedelta(49)),
"lundi_pentecote": (dp + datetime.timedelta(50)),
"14juillet": datetime.date(year, 7, 14),
"15aout": datetime.date(year, 8, 15),
"1novembre": datetime.date(year, 11, 1),
"11novembre": datetime.date(year, 11, 11),
"noel": datetime.date(year, 12, 25),
"saint_etienne": datetime.date(year, 12, 26),
'1janvier': datetime.date(year, 1, 1),
'paques': dp,
'lundi_paques': (dp + datetime.timedelta(1)),
'1mai': datetime.date(year, 5, 1),
'8mai': datetime.date(year, 5, 8),
'jeudi_ascension': (dp + datetime.timedelta(39)),
'pentecote': (dp + datetime.timedelta(49)),
'lundi_pentecote': (dp + datetime.timedelta(50)),
'14juillet': datetime.date(year, 7, 14),
'15aout': datetime.date(year, 8, 15),
'1novembre': datetime.date(year, 11, 1),
'11novembre': datetime.date(year, 11, 11),
'noel': datetime.date(year, 12, 25),
'saint_etienne': datetime.date(year, 12, 26),
}
@ -66,7 +68,7 @@ def parse_exceptional_closures(values):
for word in words:
if not word:
continue
parts = word.split("-")
parts = word.split('-')
if len(parts) == 1:
# ex: 31/02/2017
ptime = time.strptime(word, date_format)
@ -80,7 +82,7 @@ def parse_exceptional_closures(values):
pstart = time.strptime(parts[0], date_format)
pstop = time.strptime(parts[1], date_format)
if pstop <= pstart:
raise ValueError(f"Day {parts[1]} <= {parts[0]}")
raise ValueError(f'Day {parts[1]} <= {parts[0]}')
date = datetime.date(pstart.tm_year, pstart.tm_mon, pstart.tm_mday)
stop_date = datetime.date(pstop.tm_year, pstop.tm_mon, pstop.tm_mday)
@ -97,13 +99,13 @@ def parse_exceptional_closures(values):
hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0))
hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0))
if hstop <= hstart:
raise ValueError(f"Time {parts[1]} <= {parts[0]}")
hours_periods.append({"start": hstart, "stop": hstop})
raise ValueError(f'Time {parts[1]} <= {parts[0]}')
hours_periods.append({'start': hstart, 'stop': hstop})
else:
raise ValueError(f'Invalid number of part in this word: "{word}"')
if not days:
raise ValueError(f'No days found in value "{value}"')
exceptional_closures.append({"days": days, "hours_periods": hours_periods})
exceptional_closures.append({'days': days, 'hours_periods': hours_periods})
return exceptional_closures
@ -117,7 +119,7 @@ def parse_normal_opening_hours(values):
for word in words:
if not word:
continue
parts = word.split("-")
parts = word.split('-')
if len(parts) == 1:
# ex: jeudi
if word not in week_days:
@ -148,23 +150,20 @@ def parse_normal_opening_hours(values):
hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0))
hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0))
if hstop <= hstart:
raise ValueError(f"Time {parts[1]} <= {parts[0]}")
hours_periods.append({"start": hstart, "stop": hstop})
raise ValueError(f'Time {parts[1]} <= {parts[0]}')
hours_periods.append({'start': hstart, 'stop': hstop})
else:
raise ValueError(f'Invalid number of part in this word: "{word}"')
if not days and not hours_periods:
raise ValueError(f'No days or hours period found in this value: "{value}"')
normal_opening_hours.append({"days": days, "hours_periods": hours_periods})
normal_opening_hours.append({'days': days, 'hours_periods': hours_periods})
return normal_opening_hours
def is_closed(
normal_opening_hours_values=None,
exceptional_closures_values=None,
nonworking_public_holidays_values=None,
exceptional_closure_on_nonworking_public_days=False,
when=None,
on_error="raise",
normal_opening_hours_values=None, exceptional_closures_values=None,
nonworking_public_holidays_values=None, exceptional_closure_on_nonworking_public_days=False,
when=None, on_error='raise'
):
""" Check if closed """
if not when:
@ -173,26 +172,18 @@ def is_closed(
when_time = when.time()
when_weekday = week_days[when.timetuple().tm_wday]
on_error_result = None
if on_error == "closed":
if on_error == 'closed':
on_error_result = {
"closed": True,
"exceptional_closure": False,
"exceptional_closure_all_day": False,
}
elif on_error == "opened":
'closed': True, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
elif on_error == 'opened':
on_error_result = {
"closed": False,
"exceptional_closure": False,
"exceptional_closure_all_day": False,
}
'closed': False, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
log.debug(
"When = %s => date = %s / time = %s / week day = %s",
when,
when_date,
when_time,
when_weekday,
)
when, when_date, when_time, when_weekday)
if nonworking_public_holidays_values:
log.debug("Nonworking public holidays: %s", nonworking_public_holidays_values)
nonworking_days = nonworking_french_public_days_of_the_year()
@ -200,69 +191,65 @@ def is_closed(
if day in nonworking_days and when_date == nonworking_days[day]:
log.debug("Non working day: %s", day)
return {
"closed": True,
"exceptional_closure": exceptional_closure_on_nonworking_public_days,
"exceptional_closure_all_day": exceptional_closure_on_nonworking_public_days,
'closed': True,
'exceptional_closure': exceptional_closure_on_nonworking_public_days,
'exceptional_closure_all_day': exceptional_closure_on_nonworking_public_days
}
if exceptional_closures_values:
try:
exceptional_closures = parse_exceptional_closures(exceptional_closures_values)
log.debug("Exceptional closures: %s", exceptional_closures)
log.debug('Exceptional closures: %s', exceptional_closures)
except ValueError as e:
log.error("Fail to parse exceptional closures, consider as closed", exc_info=True)
if on_error_result is None:
raise e from e
return on_error_result
for cl in exceptional_closures:
if when_date not in cl["days"]:
log.debug("when_date (%s) no in days (%s)", when_date, cl["days"])
if when_date not in cl['days']:
log.debug("when_date (%s) no in days (%s)", when_date, cl['days'])
continue
if not cl["hours_periods"]:
if not cl['hours_periods']:
# All day exceptional closure
return {
"closed": True,
"exceptional_closure": True,
"exceptional_closure_all_day": True,
}
for hp in cl["hours_periods"]:
if hp["start"] <= when_time <= hp["stop"]:
'closed': True, 'exceptional_closure': True,
'exceptional_closure_all_day': True}
for hp in cl['hours_periods']:
if hp['start'] <= when_time <= hp['stop']:
return {
"closed": True,
"exceptional_closure": True,
"exceptional_closure_all_day": False,
}
'closed': True, 'exceptional_closure': True,
'exceptional_closure_all_day': False}
if normal_opening_hours_values:
try:
normal_opening_hours = parse_normal_opening_hours(normal_opening_hours_values)
log.debug("Normal opening hours: %s", normal_opening_hours)
log.debug('Normal opening hours: %s', normal_opening_hours)
except ValueError as e: # pylint: disable=broad-except
log.error("Fail to parse normal opening hours, consider as closed", exc_info=True)
if on_error_result is None:
raise e from e
return on_error_result
for oh in normal_opening_hours:
if oh["days"] and when_weekday not in oh["days"]:
log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh["days"])
if oh['days'] and when_weekday not in oh['days']:
log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh['days'])
continue
if not oh["hours_periods"]:
if not oh['hours_periods']:
# All day opened
return {
"closed": False,
"exceptional_closure": False,
"exceptional_closure_all_day": False,
}
for hp in oh["hours_periods"]:
if hp["start"] <= when_time <= hp["stop"]:
'closed': False, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
for hp in oh['hours_periods']:
if hp['start'] <= when_time <= hp['stop']:
return {
"closed": False,
"exceptional_closure": False,
"exceptional_closure_all_day": False,
}
'closed': False, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
log.debug("Not in normal opening hours => closed")
return {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
return {
'closed': True, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
# Not a nonworking day, not during exceptional closure and no normal opening
# hours defined => Opened
return {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False}
return {
'closed': False, 'exceptional_closure': False,
'exceptional_closure_all_day': False}

View file

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
""" Oracle client """
import logging
@ -5,7 +7,8 @@ import sys
import cx_Oracle
from mylib.db import DB, DBFailToConnect
from mylib.db import DB
from mylib.db import DBFailToConnect
log = logging.getLogger(__name__)
@ -26,20 +29,22 @@ class OracleDB(DB):
def connect(self, exit_on_error=True):
""" Connect to Oracle server """
if self._conn is None:
log.info("Connect on Oracle server with DSN %s as %s", self._dsn, self._user)
log.info('Connect on Oracle server with DSN %s as %s', self._dsn, self._user)
try:
self._conn = cx_Oracle.connect(user=self._user, password=self._pwd, dsn=self._dsn)
self._conn = cx_Oracle.connect(
user=self._user,
password=self._pwd,
dsn=self._dsn
)
except cx_Oracle.Error as err:
log.fatal(
"An error occured during Oracle database connection (%s@%s).",
self._user,
self._dsn,
exc_info=1,
'An error occured during Oracle database connection (%s@%s).',
self._user, self._dsn, exc_info=1
)
if exit_on_error:
sys.exit(1)
else:
raise DBFailToConnect(f"{self._user}@{self._dsn}") from err
raise DBFailToConnect(f'{self._user}@{self._dsn}') from err
return True
def doSQL(self, sql, params=None):
@ -103,4 +108,4 @@ class OracleDB(DB):
@staticmethod
def format_param(param):
""" Format SQL query parameter for prepared query """
return f":{param}"
return f':{param}'

View file

@ -1,9 +1,11 @@
# coding: utf8
""" Progress bar """
import logging
import progressbar
log = logging.getLogger(__name__)
@ -23,15 +25,15 @@ class Pbar: # pylint: disable=useless-object-inheritance
self.__count = 0
self.__pbar = progressbar.ProgressBar(
widgets=[
name + ": ",
name + ': ',
progressbar.Percentage(),
" ",
' ',
progressbar.Bar(),
" ",
' ',
progressbar.SimpleProgress(),
progressbar.ETA(),
progressbar.ETA()
],
maxval=maxval,
maxval=maxval
).start()
else:
log.info(name)

View file

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
""" PostgreSQL client """
import datetime
@ -19,8 +21,8 @@ class PgDB(DB):
_pwd = None
_db = None
date_format = "%Y-%m-%d"
datetime_format = "%Y-%m-%d %H:%M:%S"
date_format = '%Y-%m-%d'
datetime_format = '%Y-%m-%d %H:%M:%S'
def __init__(self, host, user, pwd, db, **kwargs):
self._host = host
@ -34,26 +36,23 @@ class PgDB(DB):
if self._conn is None:
try:
log.info(
"Connect on PostgreSQL server %s as %s on database %s",
self._host,
self._user,
self._db,
)
'Connect on PostgreSQL server %s as %s on database %s',
self._host, self._user, self._db)
self._conn = psycopg2.connect(
dbname=self._db, user=self._user, host=self._host, password=self._pwd
dbname=self._db,
user=self._user,
host=self._host,
password=self._pwd
)
except psycopg2.Error as err:
log.fatal(
"An error occured during Postgresql database connection (%s@%s, database=%s).",
self._user,
self._host,
self._db,
exc_info=1,
'An error occured during Postgresql database connection (%s@%s, database=%s).',
self._user, self._host, self._db, exc_info=1
)
if exit_on_error:
sys.exit(1)
else:
raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err
raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err
return True
def close(self):
@ -71,8 +70,7 @@ class PgDB(DB):
except psycopg2.Error:
log.error(
'An error occured setting Postgresql database connection encoding to "%s"',
enc,
exc_info=1,
enc, exc_info=1
)
return False
@ -126,7 +124,10 @@ class PgDB(DB):
@staticmethod
def _map_row_fields_by_index(fields, row):
return {field: row[idx] for idx, field in enumerate(fields)}
return dict(
(field, row[idx])
for idx, field in enumerate(fields)
)
#
# Depreated helpers
@ -136,7 +137,7 @@ class PgDB(DB):
def _quote_value(cls, value):
""" Quote a value for SQL query """
if value is None:
return "NULL"
return 'NULL'
if isinstance(value, (int, float)):
return str(value)
@ -147,7 +148,7 @@ class PgDB(DB):
value = cls._format_date(value)
# pylint: disable=consider-using-f-string
return "'{}'".format(value.replace("'", "''"))
return "'{0}'".format(value.replace("'", "''"))
@classmethod
def _format_datetime(cls, value):

View file

@ -1,24 +1,28 @@
# coding: utf8
""" Report """
import atexit
import logging
from mylib.config import ConfigurableObject, StringOption
from mylib.config import ConfigurableObject
from mylib.config import StringOption
from mylib.email import EmailClient
log = logging.getLogger(__name__)
class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
""" Logging report """
_config_name = "report"
_config_comment = "Email report"
_config_name = 'report'
_config_comment = 'Email report'
_defaults = {
"recipient": None,
"subject": "Report",
"loglevel": "WARNING",
"logformat": "%(asctime)s - %(levelname)s - %(message)s",
'recipient': None,
'subject': 'Report',
'loglevel': 'WARNING',
'logformat': '%(asctime)s - %(levelname)s - %(message)s',
}
content = []
@ -39,25 +43,17 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
""" Configure options on registered mylib.Config object """
section = super().configure(**kwargs)
section.add_option(StringOption, "recipient", comment="Report recipient email address")
section.add_option(
StringOption,
"subject",
default=self._defaults["subject"],
comment="Report email subject",
)
StringOption, 'recipient', comment='Report recipient email address')
section.add_option(
StringOption,
"loglevel",
default=self._defaults["loglevel"],
comment='Report log level (as accept by python logging, for instance "INFO")',
)
StringOption, 'subject', default=self._defaults['subject'],
comment='Report email subject')
section.add_option(
StringOption,
"logformat",
default=self._defaults["logformat"],
comment='Report log level (as accept by python logging, for instance "INFO")',
)
StringOption, 'loglevel', default=self._defaults['loglevel'],
comment='Report log level (as accept by python logging, for instance "INFO")')
section.add_option(
StringOption, 'logformat', default=self._defaults['logformat'],
comment='Report log level (as accept by python logging, for instance "INFO")')
if not self.email_client:
self.email_client = EmailClient(config=self._config)
@ -70,11 +66,12 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
super().initialize(loaded_config=loaded_config)
self.handler = logging.StreamHandler(self)
loglevel = self._get_option("loglevel").upper()
assert hasattr(logging, loglevel), f"Invalid report loglevel {loglevel}"
loglevel = self._get_option('loglevel').upper()
assert hasattr(logging, loglevel), (
f'Invalid report loglevel {loglevel}')
self.handler.setLevel(getattr(logging, loglevel))
self.formatter = logging.Formatter(self._get_option("logformat"))
self.formatter = logging.Formatter(self._get_option('logformat'))
self.handler.setFormatter(self.formatter)
def get_handler(self):
@ -100,34 +97,29 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
def send(self, subject=None, rcpt_to=None, email_client=None, just_try=False):
""" Send report using an EmailClient """
if rcpt_to is None:
rcpt_to = self._get_option("recipient")
rcpt_to = self._get_option('recipient')
if not rcpt_to:
log.debug("No report recipient, do not send report")
log.debug('No report recipient, do not send report')
return True
if subject is None:
subject = self._get_option("subject")
subject = self._get_option('subject')
assert subject, "You must provide report subject using Report.__init__ or Report.send"
if email_client is None:
email_client = self.email_client
assert email_client, (
"You must provide an email client __init__(), send() or send_at_exit() methods argument"
" email_client"
)
"You must provide an email client __init__(), send() or send_at_exit() methods argument email_client")
content = self.get_content()
if not content:
log.debug("Report is empty, do not send it")
log.debug('Report is empty, do not send it')
return True
msg = email_client.forge_message(
rcpt_to,
subject=subject,
text_body=content,
rcpt_to, subject=subject, text_body=content,
attachment_files=self._attachment_files,
attachment_payloads=self._attachment_payloads,
)
attachment_payloads=self._attachment_payloads)
if email_client.send(rcpt_to, msg=msg, just_try=just_try):
log.debug("Report sent to %s", rcpt_to)
log.debug('Report sent to %s', rcpt_to)
return True
log.error("Fail to send report to %s", rcpt_to)
log.error('Fail to send report to %s', rcpt_to)
return False
def send_at_exit(self, **kwargs):

View file

@ -1,14 +1,18 @@
# -*- coding: utf-8 -*-
""" Test Email client """
import datetime
import getpass
import logging
import sys
import getpass
from mako.template import Template as MakoTemplate
from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging
from mylib.scripts.helpers import get_opts_parser, add_email_opts
from mylib.scripts.helpers import init_logging, init_email_client
log = logging.getLogger("mylib.scripts.email_test")
log = logging.getLogger('mylib.scripts.email_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -20,11 +24,10 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True)
add_email_opts(parser)
test_opts = parser.add_argument_group("Test email options")
test_opts = parser.add_argument_group('Test email options')
test_opts.add_argument(
"-t",
"--to",
'-t', '--to',
action="store",
type=str,
dest="test_to",
@ -32,8 +35,7 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
)
test_opts.add_argument(
"-m",
"--mako",
'-m', '--mako',
action="store_true",
dest="test_mako",
help="Test mako templating",
@ -42,14 +44,14 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args()
if not options.test_to:
parser.error("You must specify test email recipient using -t/--to parameter")
parser.error('You must specify test email recipient using -t/--to parameter')
sys.exit(1)
# Initialize logs
init_logging(options, "Test EmailClient")
init_logging(options, 'Test EmailClient')
if options.email_smtp_user and not options.email_smtp_password:
options.email_smtp_password = getpass.getpass("Please enter SMTP password: ")
options.email_smtp_password = getpass.getpass('Please enter SMTP password: ')
email_client = init_email_client(
options,
@ -57,24 +59,20 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
test=dict(
subject="Test email",
text=(
"Just a test email sent at {sent_date}."
if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
"Just a test email sent at {sent_date}." if not options.test_mako else
MakoTemplate("Just a test email sent at ${sent_date}.")
),
html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
if not options.test_mako
else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>")
)
)
),
)
),
)
log.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
log.info("Test email sent")
log.info('Send a test email to %s', options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()):
log.info('Test email sent')
sys.exit(0)
log.error("Fail to send test email")
log.error('Fail to send test email')
sys.exit(1)

View file

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
""" Test Email client using mylib.config.Config for configuration """
import datetime
import logging
@ -8,6 +10,7 @@ from mako.template import Template as MakoTemplate
from mylib.config import Config
from mylib.email import EmailClient
log = logging.getLogger(__name__)
@ -16,7 +19,7 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
if argv is None:
argv = sys.argv[1:]
config = Config(__doc__, __name__.replace(".", "_"))
config = Config(__doc__, __name__.replace('.', '_'))
email_client = EmailClient(config=config)
email_client.configure()
@ -24,11 +27,10 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
# Options parser
parser = config.get_arguments_parser(description=__doc__)
test_opts = parser.add_argument_group("Test email options")
test_opts = parser.add_argument_group('Test email options')
test_opts.add_argument(
"-t",
"--to",
'-t', '--to',
action="store",
type=str,
dest="test_to",
@ -36,8 +38,7 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
)
test_opts.add_argument(
"-m",
"--mako",
'-m', '--mako',
action="store_true",
dest="test_mako",
help="Test mako templating",
@ -46,30 +47,26 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = config.parse_arguments_options()
if not options.test_to:
parser.error("You must specify test email recipient using -t/--to parameter")
parser.error('You must specify test email recipient using -t/--to parameter')
sys.exit(1)
email_client.templates = dict(
test=dict(
subject="Test email",
text=(
"Just a test email sent at {sent_date}."
if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
"Just a test email sent at {sent_date}." if not options.test_mako else
MakoTemplate("Just a test email sent at ${sent_date}.")
),
html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
if not options.test_mako
else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>")
)
),
)
)
logging.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
logging.info("Test email sent")
logging.info('Send a test email to %s', options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()):
logging.info('Test email sent')
sys.exit(0)
logging.error("Fail to send test email")
logging.error('Fail to send test email')
sys.exit(1)

View file

@ -1,18 +1,20 @@
# coding: utf8
""" Scripts helpers """
import argparse
import getpass
import logging
import os.path
import socket
import sys
import os.path
log = logging.getLogger(__name__)
def init_logging(options, name, report=None):
""" Initialize logging from calling script options """
logformat = f"%(asctime)s - {name} - %(levelname)s - %(message)s"
logformat = f'%(asctime)s - {name} - %(levelname)s - %(message)s'
if options.debug:
loglevel = logging.DEBUG
elif options.verbose:
@ -44,44 +46,59 @@ def get_opts_parser(desc=None, just_try=False, just_one=False, progress=False, c
parser = argparse.ArgumentParser(description=desc)
parser.add_argument(
"-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode"
'-v', '--verbose',
action="store_true",
dest="verbose",
help="Enable verbose mode"
)
parser.add_argument(
"-d", "--debug", action="store_true", dest="debug", help="Enable debug mode"
'-d', '--debug',
action="store_true",
dest="debug",
help="Enable debug mode"
)
parser.add_argument(
"-l",
"--log-file",
'-l', '--log-file',
action="store",
type=str,
dest="logfile",
help=f'Log file path (default: {get_default_opt_value(config, default_config, "logfile")})',
default=get_default_opt_value(config, default_config, "logfile"),
help=(
'Log file path (default: '
f'{get_default_opt_value(config, default_config, "logfile")})'),
default=get_default_opt_value(config, default_config, 'logfile')
)
parser.add_argument(
"-C",
"--console",
'-C', '--console',
action="store_true",
dest="console",
help="Always log on console (even if log file is configured)",
help="Always log on console (even if log file is configured)"
)
if just_try:
parser.add_argument(
"-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode"
'-j', '--just-try',
action="store_true",
dest="just_try",
help="Enable just-try mode"
)
if just_one:
parser.add_argument(
"-J", "--just-one", action="store_true", dest="just_one", help="Enable just-one mode"
'-J', '--just-one',
action="store_true",
dest="just_one",
help="Enable just-one mode"
)
if progress:
parser.add_argument(
"-p", "--progress", action="store_true", dest="progress", help="Enable progress bar"
'-p', '--progress',
action="store_true",
dest="progress",
help="Enable progress bar"
)
return parser
@ -89,143 +106,120 @@ def get_opts_parser(desc=None, just_try=False, just_one=False, progress=False, c
def add_email_opts(parser, config=None):
""" Add email options """
email_opts = parser.add_argument_group("Email options")
email_opts = parser.add_argument_group('Email options')
default_config = dict(
smtp_host="127.0.0.1",
smtp_port=25,
smtp_ssl=False,
smtp_tls=False,
smtp_user=None,
smtp_password=None,
smtp_debug=False,
email_encoding=sys.getdefaultencoding(),
sender_name=getpass.getuser(),
sender_email=f"{getpass.getuser()}@{socket.gethostname()}",
catch_all=False,
smtp_host="127.0.0.1", smtp_port=25, smtp_ssl=False, smtp_tls=False, smtp_user=None,
smtp_password=None, smtp_debug=False, email_encoding=sys.getdefaultencoding(),
sender_name=getpass.getuser(), sender_email=f'{getpass.getuser()}@{socket.gethostname()}',
catch_all=False
)
email_opts.add_argument(
"--smtp-host",
'--smtp-host',
action="store",
type=str,
dest="email_smtp_host",
help=f'SMTP host (default: {get_default_opt_value(config, default_config, "smtp_host")})',
default=get_default_opt_value(config, default_config, "smtp_host"),
help=(
'SMTP host (default: '
f'{get_default_opt_value(config, default_config, "smtp_host")})'),
default=get_default_opt_value(config, default_config, 'smtp_host')
)
email_opts.add_argument(
"--smtp-port",
'--smtp-port',
action="store",
type=int,
dest="email_smtp_port",
help=f'SMTP port (default: {get_default_opt_value(config, default_config, "smtp_port")})',
default=get_default_opt_value(config, default_config, "smtp_port"),
default=get_default_opt_value(config, default_config, 'smtp_port')
)
email_opts.add_argument(
"--smtp-ssl",
'--smtp-ssl',
action="store_true",
dest="email_smtp_ssl",
help=f'Use SSL (default: {get_default_opt_value(config, default_config, "smtp_ssl")})',
default=get_default_opt_value(config, default_config, "smtp_ssl"),
default=get_default_opt_value(config, default_config, 'smtp_ssl')
)
email_opts.add_argument(
"--smtp-tls",
'--smtp-tls',
action="store_true",
dest="email_smtp_tls",
help=f'Use TLS (default: {get_default_opt_value(config, default_config, "smtp_tls")})',
default=get_default_opt_value(config, default_config, "smtp_tls"),
default=get_default_opt_value(config, default_config, 'smtp_tls')
)
email_opts.add_argument(
"--smtp-user",
'--smtp-user',
action="store",
type=str,
dest="email_smtp_user",
help=(
f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})'
),
default=get_default_opt_value(config, default_config, "smtp_user"),
help=f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})',
default=get_default_opt_value(config, default_config, 'smtp_user')
)
email_opts.add_argument(
"--smtp-password",
'--smtp-password',
action="store",
type=str,
dest="email_smtp_password",
help=(
"SMTP password (default:"
f' {get_default_opt_value(config, default_config, "smtp_password")})'
),
default=get_default_opt_value(config, default_config, "smtp_password"),
help=f'SMTP password (default: {get_default_opt_value(config, default_config, "smtp_password")})',
default=get_default_opt_value(config, default_config, 'smtp_password')
)
email_opts.add_argument(
"--smtp-debug",
'--smtp-debug',
action="store_true",
dest="email_smtp_debug",
help=(
"Debug SMTP connection (default:"
f' {get_default_opt_value(config, default_config, "smtp_debug")})'
),
default=get_default_opt_value(config, default_config, "smtp_debug"),
help=f'Debug SMTP connection (default: {get_default_opt_value(config, default_config, "smtp_debug")})',
default=get_default_opt_value(config, default_config, 'smtp_debug')
)
email_opts.add_argument(
"--email-encoding",
'--email-encoding',
action="store",
type=str,
dest="email_encoding",
help=(
"SMTP encoding (default:"
f' {get_default_opt_value(config, default_config, "email_encoding")})'
),
default=get_default_opt_value(config, default_config, "email_encoding"),
help=f'SMTP encoding (default: {get_default_opt_value(config, default_config, "email_encoding")})',
default=get_default_opt_value(config, default_config, 'email_encoding')
)
email_opts.add_argument(
"--sender-name",
'--sender-name',
action="store",
type=str,
dest="email_sender_name",
help=(
f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})'
),
default=get_default_opt_value(config, default_config, "sender_name"),
help=f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})',
default=get_default_opt_value(config, default_config, 'sender_name')
)
email_opts.add_argument(
"--sender-email",
'--sender-email',
action="store",
type=str,
dest="email_sender_email",
help=(
"Sender email (default:"
f' {get_default_opt_value(config, default_config, "sender_email")})'
),
default=get_default_opt_value(config, default_config, "sender_email"),
help=f'Sender email (default: {get_default_opt_value(config, default_config, "sender_email")})',
default=get_default_opt_value(config, default_config, 'sender_email')
)
email_opts.add_argument(
"--catch-all",
'--catch-all',
action="store",
type=str,
dest="email_catch_all",
help=(
"Catch all sent email: specify catch recipient email address "
f'(default: {get_default_opt_value(config, default_config, "catch_all")})'
),
default=get_default_opt_value(config, default_config, "catch_all"),
'Catch all sent email: specify catch recipient email address '
f'(default: {get_default_opt_value(config, default_config, "catch_all")})'),
default=get_default_opt_value(config, default_config, 'catch_all')
)
def init_email_client(options, **kwargs):
""" Initialize email client from calling script options """
from mylib.email import EmailClient # pylint: disable=import-outside-toplevel
log.info("Initialize Email client")
log.info('Initialize Email client')
return EmailClient(
smtp_host=options.email_smtp_host,
smtp_port=options.email_smtp_port,
@ -237,9 +231,9 @@ def init_email_client(options, **kwargs):
sender_name=options.email_sender_name,
sender_email=options.email_sender_email,
catch_all_addr=options.email_catch_all,
just_try=options.just_try if hasattr(options, "just_try") else False,
just_try=options.just_try if hasattr(options, 'just_try') else False,
encoding=options.email_encoding,
**kwargs,
**kwargs
)
@ -248,51 +242,53 @@ def add_sftp_opts(parser):
sftp_opts = parser.add_argument_group("SFTP options")
sftp_opts.add_argument(
"-H",
"--sftp-host",
'-H', '--sftp-host',
action="store",
type=str,
dest="sftp_host",
help="SFTP Host (default: localhost)",
default="localhost",
default='localhost'
)
sftp_opts.add_argument(
"--sftp-port",
'--sftp-port',
action="store",
type=int,
dest="sftp_port",
help="SFTP Port (default: 22)",
default=22,
default=22
)
sftp_opts.add_argument(
"-u", "--sftp-user", action="store", type=str, dest="sftp_user", help="SFTP User"
'-u', '--sftp-user',
action="store",
type=str,
dest="sftp_user",
help="SFTP User"
)
sftp_opts.add_argument(
"-P",
"--sftp-password",
'-P', '--sftp-password',
action="store",
type=str,
dest="sftp_password",
help="SFTP Password",
help="SFTP Password"
)
sftp_opts.add_argument(
"--sftp-known-hosts",
'--sftp-known-hosts',
action="store",
type=str,
dest="sftp_known_hosts",
help="SFTP known_hosts file path (default: ~/.ssh/known_hosts)",
default=os.path.expanduser("~/.ssh/known_hosts"),
default=os.path.expanduser('~/.ssh/known_hosts')
)
sftp_opts.add_argument(
"--sftp-auto-add-unknown-host-key",
'--sftp-auto-add-unknown-host-key',
action="store_true",
dest="sftp_auto_add_unknown_host_key",
help="Auto-add unknown SSH host key",
help="Auto-add unknown SSH host key"
)
return sftp_opts

View file

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
""" Test LDAP """
import datetime
import logging
@ -6,10 +8,12 @@ import sys
import dateutil.tz
import pytz
from mylib.ldap import format_date, format_datetime, parse_date, parse_datetime
from mylib.scripts.helpers import get_opts_parser, init_logging
from mylib.ldap import format_datetime, format_date, parse_datetime, parse_date
from mylib.scripts.helpers import get_opts_parser
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.ldap_test")
log = logging.getLogger('mylib.scripts.ldap_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -22,121 +26,52 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args()
# Initialize logs
init_logging(options, "Test LDAP helpers")
init_logging(options, 'Test LDAP helpers')
now = datetime.datetime.now().replace(tzinfo=dateutil.tz.tzlocal())
print(f"Now = {now}")
print(f'Now = {now}')
datestring_now = format_datetime(now)
print(f"format_datetime : {datestring_now}")
print(
"format_datetime (from_timezone=utc) :"
f" {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}"
)
print(
"format_datetime (from_timezone=local) :"
f" {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}"
)
print(
"format_datetime (from_timezone=local) :"
f' {format_datetime(now.replace(tzinfo=None), from_timezone="local")}'
)
print(
"format_datetime (from_timezone=Paris) :"
f' {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}'
)
print(f"format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}")
print(
"format_datetime (to_timezone=local) :"
f" {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'format_datetime : {datestring_now}')
print(f'format_datetime (from_timezone=utc) : {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}')
print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}')
print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone="local")}')
print(f'format_datetime (from_timezone=Paris) : {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}')
print(f'format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}')
print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}')
print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone="local")}')
print(f'format_datetime (to_timezone=Tokyo) : {format_datetime(now, to_timezone="Asia/Tokyo")}')
print(f"format_datetime (naive=True) : {format_datetime(now, naive=True)}")
print(f'format_datetime (naive=True) : {format_datetime(now, naive=True)}')
print(f"format_date : {format_date(now)}")
print(
"format_date (from_timezone=utc) :"
f" {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}"
)
print(
"format_date (from_timezone=local) :"
f" {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}"
)
print(
"format_date (from_timezone=local) :"
f' {format_date(now.replace(tzinfo=None), from_timezone="local")}'
)
print(
"format_date (from_timezone=Paris) :"
f' {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}'
)
print(f"format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}")
print(
f"format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'format_date : {format_date(now)}')
print(f'format_date (from_timezone=utc) : {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}')
print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}')
print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone="local")}')
print(f'format_date (from_timezone=Paris) : {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}')
print(f'format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}')
print(f'format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}')
print(f'format_date (to_timezone=local) : {format_date(now, to_timezone="local")}')
print(f'format_date (to_timezone=Tokyo) : {format_date(now, to_timezone="Asia/Tokyo")}')
print(f"format_date (naive=True) : {format_date(now, naive=True)}")
print(f'format_date (naive=True) : {format_date(now, naive=True)}')
print(f"parse_datetime : {parse_datetime(datestring_now)}")
print(
"parse_datetime (default_timezone=utc) :"
f" {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}"
)
print(
"parse_datetime (default_timezone=local) :"
f" {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}"
)
print(
"parse_datetime (default_timezone=local) :"
f' {parse_datetime(datestring_now[0:-1], default_timezone="local")}'
)
print(
"parse_datetime (default_timezone=Paris) :"
f' {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}'
)
print(
f"parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}"
)
print(
"parse_datetime (to_timezone=local) :"
f" {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}"
)
print(
"parse_datetime (to_timezone=local) :"
f' {parse_datetime(datestring_now, to_timezone="local")}'
)
print(
"parse_datetime (to_timezone=Tokyo) :"
f' {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}'
)
print(f"parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}")
print(f'parse_datetime : {parse_datetime(datestring_now)}')
print(f'parse_datetime (default_timezone=utc) : {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}')
print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}')
print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone="local")}')
print(f'parse_datetime (default_timezone=Paris) : {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}')
print(f'parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}')
print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}')
print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone="local")}')
print(f'parse_datetime (to_timezone=Tokyo) : {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}')
print(f'parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}')
print(f"parse_date : {parse_date(datestring_now)}")
print(
"parse_date (default_timezone=utc) :"
f" {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}"
)
print(
"parse_date (default_timezone=local) :"
f" {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}"
)
print(
"parse_date (default_timezone=local) :"
f' {parse_date(datestring_now[0:-1], default_timezone="local")}'
)
print(
"parse_date (default_timezone=Paris) :"
f' {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}'
)
print(f"parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}")
print(
"parse_date (to_timezone=local) :"
f" {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'parse_date : {parse_date(datestring_now)}')
print(f'parse_date (default_timezone=utc) : {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}')
print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}')
print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone="local")}')
print(f'parse_date (default_timezone=Paris) : {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}')
print(f'parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}')
print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}')
print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone="local")}')
print(
f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}'
)
print(f"parse_date (naive=True) : {parse_date(datestring_now, naive=True)}")
print(f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}')
print(f'parse_date (naive=True) : {parse_date(datestring_now, naive=True)}')

View file

@ -1,69 +0,0 @@
""" Test mapping """
import logging
import sys
from mylib import pretty_format_value
from mylib.mapping import map_hash
from mylib.scripts.helpers import get_opts_parser, init_logging
log = logging.getLogger(__name__)
def main(argv=None):
"""Script main"""
if argv is None:
argv = sys.argv[1:]
# Options parser
parser = get_opts_parser(progress=True)
options = parser.parse_args()
# Initialize logs
init_logging(options, "Test mapping")
src = {
"uid": "hmartin",
"firstname": "Martin",
"lastname": "Martin",
"disp_name": "Henri Martin",
"line_1": "3 rue de Paris",
"line_2": "Pour Pierre",
"zip_text": "92 120",
"city_text": "Montrouge",
"line_city": "92120 Montrouge",
"tel1": "01 00 00 00 00",
"tel2": "09 00 00 00 00",
"mobile": "06 00 00 00 00",
"fax": "01 00 00 00 00",
"email": "H.MARTIN@GMAIL.COM",
}
map_c = {
"uid": {"order": 0, "key": "uid", "required": True},
"givenName": {"order": 1, "key": "firstname"},
"sn": {"order": 2, "key": "lastname"},
"cn": {
"order": 3,
"key": "disp_name",
"required": True,
"or": {"attrs": ["firstname", "lastname"], "join": " "},
},
"displayName": {"order": 4, "other_key": "displayName"},
"street": {"order": 5, "join": " / ", "keys": ["ligne_1", "ligne_2"]},
"postalCode": {"order": 6, "key": "zip_text", "cleanRegex": "[^0-9]"},
"l": {"order": 7, "key": "city_text"},
"postalAddress": {"order": 8, "join": "$", "keys": ["ligne_1", "ligne_2", "ligne_city"]},
"telephoneNumber": {
"order": 9,
"keys": ["tel1", "tel2"],
"cleanRegex": "[^0-9+]",
"deduplicate": True,
},
"mobile": {"order": 10, "key": "mobile"},
"facsimileTelephoneNumber": {"order": 11, "key": "fax"},
"mail": {"order": 12, "key": "email", "convert": lambda x: x.lower().strip()},
}
print("Mapping source:\n" + pretty_format_value(src))
print("Mapping config:\n" + pretty_format_value(map_c))
print("Mapping result:\n" + pretty_format_value(map_hash(map_c, src)))

View file

@ -1,12 +1,16 @@
# -*- coding: utf-8 -*-
""" Test Progress bar """
import logging
import sys
import time
import sys
from mylib.pbar import Pbar
from mylib.scripts.helpers import get_opts_parser, init_logging
from mylib.scripts.helpers import get_opts_parser
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.pbar_test")
log = logging.getLogger('mylib.scripts.pbar_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -19,21 +23,20 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(progress=True)
parser.add_argument(
"-c",
"--count",
'-c', '--count',
action="store",
type=int,
dest="count",
help=f"Progress bar max value (default: {default_max_val})",
default=default_max_val,
help=f'Progress bar max value (default: {default_max_val})',
default=default_max_val
)
options = parser.parse_args()
# Initialize logs
init_logging(options, "Test Pbar")
init_logging(options, 'Test Pbar')
pbar = Pbar("Test", options.count, enabled=options.progress)
pbar = Pbar('Test', options.count, enabled=options.progress)
for idx in range(0, options.count): # pylint: disable=unused-variable
pbar.increment()

View file

@ -1,11 +1,15 @@
# -*- coding: utf-8 -*-
""" Test report """
import logging
import sys
from mylib.report import Report
from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging
from mylib.scripts.helpers import get_opts_parser, add_email_opts
from mylib.scripts.helpers import init_logging, init_email_client
log = logging.getLogger("mylib.scripts.report_test")
log = logging.getLogger('mylib.scripts.report_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -17,10 +21,14 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True)
add_email_opts(parser)
report_opts = parser.add_argument_group("Report options")
report_opts = parser.add_argument_group('Report options')
report_opts.add_argument(
"-t", "--to", action="store", type=str, dest="report_rcpt", help="Send report to this email"
'-t', '--to',
action="store",
type=str,
dest="report_rcpt",
help="Send report to this email"
)
options = parser.parse_args()
@ -29,13 +37,13 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser.error("You must specify a report recipient using -t/--to parameter")
# Initialize logs
report = Report(rcpt_to=options.report_rcpt, subject="Test report")
init_logging(options, "Test Report", report=report)
report = Report(rcpt_to=options.report_rcpt, subject='Test report')
init_logging(options, 'Test Report', report=report)
email_client = init_email_client(options)
report.send_at_exit(email_client=email_client)
logging.debug("Test debug message")
logging.info("Test info message")
logging.warning("Test warning message")
logging.error("Test error message")
logging.debug('Test debug message')
logging.info('Test info message')
logging.warning('Test warning message')
logging.error('Test error message')

View file

@ -1,17 +1,22 @@
# -*- coding: utf-8 -*-
""" Test SFTP client """
import atexit
import getpass
import tempfile
import logging
import sys
import os
import random
import string
import sys
import tempfile
from mylib.scripts.helpers import add_sftp_opts, get_opts_parser, init_logging
import getpass
from mylib.sftp import SFTPClient
from mylib.scripts.helpers import get_opts_parser, add_sftp_opts
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.sftp_test")
log = logging.getLogger('mylib.scripts.sftp_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -23,11 +28,10 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True)
add_sftp_opts(parser)
test_opts = parser.add_argument_group("Test SFTP options")
test_opts = parser.add_argument_group('Test SFTP options')
test_opts.add_argument(
"-p",
"--remote-upload-path",
'-p', '--remote-upload-path',
action="store",
type=str,
dest="upload_path",
@ -37,68 +41,66 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args()
# Initialize logs
init_logging(options, "Test SFTP client")
init_logging(options, 'Test SFTP client')
if options.sftp_user and not options.sftp_password:
options.sftp_password = getpass.getpass("Please enter SFTP password: ")
options.sftp_password = getpass.getpass('Please enter SFTP password: ')
log.info("Initialize Email client")
log.info('Initialize Email client')
sftp = SFTPClient(options=options, just_try=options.just_try)
sftp.connect()
atexit.register(sftp.close)
log.debug("Create tempory file")
test_content = b"Juste un test."
log.debug('Create tempory file')
test_content = b'Juste un test.'
tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
tmp_file = os.path.join(
tmp_dir.name, f'tmp{"".join(random.choice(string.ascii_lowercase) for i in range(8))}'
tmp_dir.name,
f'tmp{"".join(random.choice(string.ascii_lowercase) for i in range(8))}'
)
log.debug('Temporary file path: "%s"', tmp_file)
with open(tmp_file, "wb") as file_desc:
with open(tmp_file, 'wb') as file_desc:
file_desc.write(test_content)
log.debug(
"Upload file %s to SFTP server (in %s)",
tmp_file,
options.upload_path if options.upload_path else "remote initial connection directory",
)
'Upload file %s to SFTP server (in %s)', tmp_file,
options.upload_path if options.upload_path else "remote initial connection directory")
if not sftp.upload_file(tmp_file, options.upload_path):
log.error("Fail to upload test file on SFTP server")
log.error('Fail to upload test file on SFTP server')
sys.exit(1)
log.info("Test file uploaded on SFTP server")
log.info('Test file uploaded on SFTP server')
remote_filepath = (
os.path.join(options.upload_path, os.path.basename(tmp_file))
if options.upload_path
else os.path.basename(tmp_file)
if options.upload_path else os.path.basename(tmp_file)
)
with tempfile.NamedTemporaryFile() as tmp_file2:
log.info("Retrieve test file to %s", tmp_file2.name)
log.info('Retrieve test file to %s', tmp_file2.name)
if not sftp.get_file(remote_filepath, tmp_file2.name):
log.error("Fail to retrieve test file")
log.error('Fail to retrieve test file')
else:
with open(tmp_file2.name, "rb") as file_desc:
with open(tmp_file2.name, 'rb') as file_desc:
content = file_desc.read()
log.debug("Read content: %s", content)
log.debug('Read content: %s', content)
if test_content == content:
log.info("Content file retrieved match with uploaded one")
log.info('Content file retrieved match with uploaded one')
else:
log.error("Content file retrieved doest not match with uploaded one")
log.error('Content file retrieved doest not match with uploaded one')
try:
log.info("Remotly open test file %s", remote_filepath)
log.info('Remotly open test file %s', remote_filepath)
file_desc = sftp.open_file(remote_filepath)
content = file_desc.read()
log.debug("Read content: %s", content)
log.debug('Read content: %s', content)
if test_content == content:
log.info("Content of remote file match with uploaded one")
log.info('Content of remote file match with uploaded one')
else:
log.error("Content of remote file doest not match with uploaded one")
log.error('Content of remote file doest not match with uploaded one')
except Exception: # pylint: disable=broad-except
log.exception("An exception occurred remotly opening test file %s", remote_filepath)
log.exception('An exception occurred remotly opening test file %s', remote_filepath)
if sftp.remove_file(remote_filepath):
log.info("Test file removed on SFTP server")
log.info('Test file removed on SFTP server')
else:
log.error("Fail to remove test file on SFTP server")
log.error('Fail to remove test file on SFTP server')

View file

@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
""" SFTP client """
import logging
import os
from paramiko import AutoAddPolicy, SFTPAttributes, SSHClient
from paramiko import SSHClient, AutoAddPolicy, SFTPAttributes
from mylib.config import (
BooleanOption,
ConfigurableObject,
IntegerOption,
PasswordOption,
StringOption,
)
from mylib.config import ConfigurableObject
from mylib.config import BooleanOption
from mylib.config import IntegerOption
from mylib.config import PasswordOption
from mylib.config import StringOption
log = logging.getLogger(__name__)
@ -23,16 +23,16 @@ class SFTPClient(ConfigurableObject):
This class abstract all interactions with the SFTP server.
"""
_config_name = "sftp"
_config_comment = "SFTP"
_config_name = 'sftp'
_config_comment = 'SFTP'
_defaults = {
"host": "localhost",
"port": 22,
"user": None,
"password": None,
"known_hosts": os.path.expanduser("~/.ssh/known_hosts"),
"auto_add_unknown_host_key": False,
"just_try": False,
'host': 'localhost',
'port': 22,
'user': None,
'password': None,
'known_hosts': os.path.expanduser('~/.ssh/known_hosts'),
'auto_add_unknown_host_key': False,
'just_try': False,
}
ssh_client = None
@ -45,48 +45,30 @@ class SFTPClient(ConfigurableObject):
section = super().configure(**kwargs)
section.add_option(
StringOption,
"host",
default=self._defaults["host"],
comment="SFTP server hostname/IP address",
)
StringOption, 'host', default=self._defaults['host'],
comment='SFTP server hostname/IP address')
section.add_option(
IntegerOption, "port", default=self._defaults["port"], comment="SFTP server port"
)
IntegerOption, 'port', default=self._defaults['port'],
comment='SFTP server port')
section.add_option(
StringOption,
"user",
default=self._defaults["user"],
comment="SFTP authentication username",
)
StringOption, 'user', default=self._defaults['user'],
comment='SFTP authentication username')
section.add_option(
PasswordOption,
"password",
default=self._defaults["password"],
PasswordOption, 'password', default=self._defaults['password'],
comment='SFTP authentication password (set to "keyring" to use XDG keyring)',
username_option="user",
keyring_value="keyring",
)
username_option='user', keyring_value='keyring')
section.add_option(
StringOption,
"known_hosts",
default=self._defaults["known_hosts"],
comment="SFTP known_hosts filepath",
)
StringOption, 'known_hosts', default=self._defaults['known_hosts'],
comment='SFTP known_hosts filepath')
section.add_option(
BooleanOption,
"auto_add_unknown_host_key",
default=self._defaults["auto_add_unknown_host_key"],
comment="Auto add unknown host key",
)
BooleanOption, 'auto_add_unknown_host_key',
default=self._defaults['auto_add_unknown_host_key'],
comment='Auto add unknown host key')
if just_try:
section.add_option(
BooleanOption,
"just_try",
default=self._defaults["just_try"],
comment="Just-try mode: do not really make change on remote SFTP host",
)
BooleanOption, 'just_try', default=self._defaults['just_try'],
comment='Just-try mode: do not really make change on remote SFTP host')
return section
@ -98,20 +80,19 @@ class SFTPClient(ConfigurableObject):
""" Connect to SFTP server """
if self.ssh_client:
return
host = self._get_option("host")
port = self._get_option("port")
host = self._get_option('host')
port = self._get_option('port')
log.info("Connect to SFTP server %s:%d", host, port)
self.ssh_client = SSHClient()
if self._get_option("known_hosts"):
self.ssh_client.load_host_keys(self._get_option("known_hosts"))
if self._get_option("auto_add_unknown_host_key"):
log.debug("Set missing host key policy to auto-add")
if self._get_option('known_hosts'):
self.ssh_client.load_host_keys(self._get_option('known_hosts'))
if self._get_option('auto_add_unknown_host_key'):
log.debug('Set missing host key policy to auto-add')
self.ssh_client.set_missing_host_key_policy(AutoAddPolicy())
self.ssh_client.connect(
host,
port=port,
username=self._get_option("user"),
password=self._get_option("password"),
host, port=port,
username=self._get_option('user'),
password=self._get_option('password')
)
self.sftp_client = self.ssh_client.open_sftp()
self.initial_directory = self.sftp_client.getcwd()
@ -127,7 +108,7 @@ class SFTPClient(ConfigurableObject):
log.debug("Retreive file '%s' to '%s'", remote_filepath, local_filepath)
return self.sftp_client.get(remote_filepath, local_filepath) is None
def open_file(self, remote_filepath, mode="r"):
def open_file(self, remote_filepath, mode='r'):
""" Remotly open a file on SFTP server """
self.connect()
log.debug("Remotly open file '%s'", remote_filepath)
@ -138,13 +119,13 @@ class SFTPClient(ConfigurableObject):
self.connect()
remote_filepath = os.path.join(
remote_directory if remote_directory else self.initial_directory,
os.path.basename(filepath),
os.path.basename(filepath)
)
log.debug("Upload file '%s' to '%s'", filepath, remote_filepath)
if self._get_option("just_try"):
if self._get_option('just_try'):
log.debug(
"Just-try mode: do not really upload file '%s' to '%s'", filepath, remote_filepath
)
"Just-try mode: do not really upload file '%s' to '%s'",
filepath, remote_filepath)
return True
result = self.sftp_client.put(filepath, remote_filepath)
return isinstance(result, SFTPAttributes)
@ -153,7 +134,7 @@ class SFTPClient(ConfigurableObject):
""" Remove a file on SFTP server """
self.connect()
log.debug("Remove file '%s'", filepath)
if self._get_option("just_try"):
if self._get_option('just_try'):
log.debug("Just - try mode: do not really remove file '%s'", filepath)
return True
return self.sftp_client.remove(filepath) is None

View file

@ -13,15 +13,15 @@ class TelltaleFile:
def __init__(self, filepath=None, filename=None, dirpath=None):
assert filepath or filename, "filename or filepath is required"
if filepath:
assert (
not filename or os.path.basename(filepath) == filename
), "filepath and filename does not match"
assert (
not dirpath or os.path.dirname(filepath) == dirpath
), "filepath and dirpath does not match"
assert not filename or os.path.basename(filepath) == filename, "filepath and filename does not match"
assert not dirpath or os.path.dirname(filepath) == dirpath, "filepath and dirpath does not match"
self.filename = filename if filename else os.path.basename(filepath)
self.dirpath = (
dirpath if dirpath else (os.path.dirname(filepath) if filepath else os.getcwd())
dirpath if dirpath
else (
os.path.dirname(filepath) if filepath
else os.getcwd()
)
)
self.filepath = filepath if filepath else os.path.join(self.dirpath, self.filename)
@ -29,19 +29,21 @@ class TelltaleFile:
def last_update(self):
""" Retreive last update datetime of the telltall file """
try:
return datetime.datetime.fromtimestamp(os.stat(self.filepath).st_mtime)
return datetime.datetime.fromtimestamp(
os.stat(self.filepath).st_mtime
)
except FileNotFoundError:
log.info("Telltale file not found (%s)", self.filepath)
log.info('Telltale file not found (%s)', self.filepath)
return None
def update(self):
""" Update the telltale file """
log.info("Update telltale file (%s)", self.filepath)
log.info('Update telltale file (%s)', self.filepath)
try:
os.utime(self.filepath, None)
except FileNotFoundError:
# pylint: disable=consider-using-with
open(self.filepath, "a", encoding="utf-8").close()
open(self.filepath, 'a', encoding="utf-8").close()
def remove(self):
""" Remove the telltale file """

View file

@ -1,3 +1,2 @@
[flake8]
ignore = E501,W503
max-line-length = 100

View file

@ -1,74 +1,75 @@
#!/usr/bin/env python
"""Setuptools script"""
# -*- coding: utf-8 -*-
from setuptools import find_packages
from setuptools import setup
from setuptools import find_packages, setup
extras_require = {
"dev": [
"pytest",
"mocker",
"pytest-mock",
"pylint == 2.15.10",
"pre-commit",
'dev': [
'pytest',
'mocker',
'pytest-mock',
'pylint',
'flake8',
],
"config": [
"argcomplete",
"keyring",
"systemd-python",
'config': [
'argcomplete',
'keyring',
'systemd-python',
],
"ldap": [
"python-ldap",
"python-dateutil",
"pytz",
'ldap': [
'python-ldap',
'python-dateutil',
'pytz',
],
"email": [
"mako",
'email': [
'mako',
],
"pgsql": [
"psycopg2",
'pgsql': [
'psycopg2',
],
"oracle": [
"cx_Oracle",
'oracle': [
'cx_Oracle',
],
"mysql": [
"mysqlclient",
'mysql': [
'mysqlclient',
],
"sftp": [
"paramiko",
'sftp': [
'paramiko',
],
}
install_requires = ["progressbar"]
install_requires = ['progressbar']
for extra, deps in extras_require.items():
if extra != "dev":
if extra != 'dev':
install_requires.extend(deps)
version = "0.1"
version = '0.1'
setup(
name="mylib",
version=version,
description="A set of helpers small libs to make common tasks easier in my script development",
description='A set of helpers small libs to make common tasks easier in my script development',
classifiers=[
"Programming Language :: Python",
'Programming Language :: Python',
],
install_requires=install_requires,
extras_require=extras_require,
author="Benjamin Renard",
author_email="brenard@zionetrix.net",
url="https://gogs.zionetrix.net/bn8/python-mylib",
author='Benjamin Renard',
author_email='brenard@zionetrix.net',
url='https://gogs.zionetrix.net/bn8/python-mylib',
packages=find_packages(),
include_package_data=True,
zip_safe=False,
entry_points={
"console_scripts": [
"mylib-test-email = mylib.scripts.email_test:main",
"mylib-test-email-with-config = mylib.scripts.email_test_with_config:main",
"mylib-test-map = mylib.scripts.map_test:main",
"mylib-test-pbar = mylib.scripts.pbar_test:main",
"mylib-test-report = mylib.scripts.report_test:main",
"mylib-test-ldap = mylib.scripts.ldap_test:main",
"mylib-test-sftp = mylib.scripts.sftp_test:main",
'console_scripts': [
'mylib-test-email = mylib.scripts.email_test:main',
'mylib-test-email-with-config = mylib.scripts.email_test_with_config:main',
'mylib-test-pbar = mylib.scripts.pbar_test:main',
'mylib-test-report = mylib.scripts.report_test:main',
'mylib-test-ldap = mylib.scripts.ldap_test:main',
'mylib-test-sftp = mylib.scripts.sftp_test:main',
],
},
)

View file

@ -22,11 +22,18 @@ echo "Install package with dev dependencies using pip..."
$VENV/bin/python3 -m pip install -e ".[dev]" $QUIET_ARG
RES=0
# Run tests
$VENV/bin/python3 -m pytest tests
[ $? -ne 0 ] && RES=1
# Run pre-commit
echo "Run pre-commit..."
source $VENV/bin/activate
pre-commit run --all-files
# Run pylint
echo "Run pylint..."
$VENV/bin/pylint --extension-pkg-whitelist=cx_Oracle mylib tests
[ $? -ne 0 ] && RES=1
# Run flake8
echo "Run flake8..."
$VENV/bin/flake8 mylib tests
[ $? -ne 0 ] && RES=1
# Clean temporary venv

View file

@ -2,29 +2,31 @@
# pylint: disable=global-variable-not-assigned
""" Tests on config lib """
import configparser
import logging
import os
import configparser
import pytest
from mylib.config import BooleanOption, Config, ConfigSection, StringOption
from mylib.config import Config, ConfigSection
from mylib.config import BooleanOption
from mylib.config import StringOption
runned = {}
def test_config_init_default_args():
appname = "Test app"
appname = 'Test app'
config = Config(appname)
assert config.appname == appname
assert config.version == "0.0"
assert config.encoding == "utf-8"
assert config.version == '0.0'
assert config.encoding == 'utf-8'
def test_config_init_custom_args():
appname = "Test app"
version = "1.43"
encoding = "ISO-8859-1"
appname = 'Test app'
version = '1.43'
encoding = 'ISO-8859-1'
config = Config(appname, version=version, encoding=encoding)
assert config.appname == appname
assert config.version == version
@ -32,8 +34,8 @@ def test_config_init_custom_args():
def test_add_section_default_args():
config = Config("Test app")
name = "test_section"
config = Config('Test app')
name = 'test_section'
section = config.add_section(name)
assert isinstance(section, ConfigSection)
assert config.sections[name] == section
@ -43,9 +45,9 @@ def test_add_section_default_args():
def test_add_section_custom_args():
config = Config("Test app")
name = "test_section"
comment = "Test"
config = Config('Test app')
name = 'test_section'
comment = 'Test'
order = 20
section = config.add_section(name, comment=comment, order=order)
assert isinstance(section, ConfigSection)
@ -55,47 +57,47 @@ def test_add_section_custom_args():
def test_add_section_with_callback():
config = Config("Test app")
name = "test_section"
config = Config('Test app')
name = 'test_section'
global runned
runned["test_add_section_with_callback"] = False
runned['test_add_section_with_callback'] = False
def test_callback(loaded_config):
global runned
assert loaded_config == config
assert runned["test_add_section_with_callback"] is False
runned["test_add_section_with_callback"] = True
assert runned['test_add_section_with_callback'] is False
runned['test_add_section_with_callback'] = True
section = config.add_section(name, loaded_callback=test_callback)
assert isinstance(section, ConfigSection)
assert test_callback in config._loaded_callbacks
assert runned["test_add_section_with_callback"] is False
assert runned['test_add_section_with_callback'] is False
config.parse_arguments_options(argv=[], create=False)
assert runned["test_add_section_with_callback"] is True
assert runned['test_add_section_with_callback'] is True
assert test_callback in config._loaded_callbacks_executed
# Try to execute again to verify callback is not runned again
config._loaded()
def test_add_section_with_callback_already_loaded():
config = Config("Test app")
name = "test_section"
config = Config('Test app')
name = 'test_section'
config.parse_arguments_options(argv=[], create=False)
global runned
runned["test_add_section_with_callback_already_loaded"] = False
runned['test_add_section_with_callback_already_loaded'] = False
def test_callback(loaded_config):
global runned
assert loaded_config == config
assert runned["test_add_section_with_callback_already_loaded"] is False
runned["test_add_section_with_callback_already_loaded"] = True
assert runned['test_add_section_with_callback_already_loaded'] is False
runned['test_add_section_with_callback_already_loaded'] = True
section = config.add_section(name, loaded_callback=test_callback)
assert isinstance(section, ConfigSection)
assert runned["test_add_section_with_callback_already_loaded"] is True
assert runned['test_add_section_with_callback_already_loaded'] is True
assert test_callback in config._loaded_callbacks
assert test_callback in config._loaded_callbacks_executed
# Try to execute again to verify callback is not runned again
@ -103,10 +105,10 @@ def test_add_section_with_callback_already_loaded():
def test_add_option_default_args():
config = Config("Test app")
section = config.add_section("my_section")
config = Config('Test app')
section = config.add_section('my_section')
assert isinstance(section, ConfigSection)
name = "my_option"
name = 'my_option'
option = section.add_option(StringOption, name)
assert isinstance(option, StringOption)
assert name in section.options and section.options[name] == option
@ -122,17 +124,17 @@ def test_add_option_default_args():
def test_add_option_custom_args():
config = Config("Test app")
section = config.add_section("my_section")
config = Config('Test app')
section = config.add_section('my_section')
assert isinstance(section, ConfigSection)
name = "my_option"
name = 'my_option'
kwargs = dict(
default="default value",
comment="my comment",
default='default value',
comment='my comment',
no_arg=True,
arg="--my-option",
short_arg="-M",
arg_help="My help",
arg='--my-option',
short_arg='-M',
arg_help='My help'
)
option = section.add_option(StringOption, name, **kwargs)
assert isinstance(option, StringOption)
@ -146,12 +148,12 @@ def test_add_option_custom_args():
def test_defined():
config = Config("Test app")
section_name = "my_section"
opt_name = "my_option"
config = Config('Test app')
section_name = 'my_section'
opt_name = 'my_option'
assert not config.defined(section_name, opt_name)
section = config.add_section("my_section")
section = config.add_section('my_section')
assert isinstance(section, ConfigSection)
section.add_option(StringOption, opt_name)
@ -159,29 +161,29 @@ def test_defined():
def test_isset():
config = Config("Test app")
section_name = "my_section"
opt_name = "my_option"
config = Config('Test app')
section_name = 'my_section'
opt_name = 'my_option'
assert not config.isset(section_name, opt_name)
section = config.add_section("my_section")
section = config.add_section('my_section')
assert isinstance(section, ConfigSection)
option = section.add_option(StringOption, opt_name)
assert not config.isset(section_name, opt_name)
config.parse_arguments_options(argv=[option.parser_argument_name, "value"], create=False)
config.parse_arguments_options(argv=[option.parser_argument_name, 'value'], create=False)
assert config.isset(section_name, opt_name)
def test_not_isset():
config = Config("Test app")
section_name = "my_section"
opt_name = "my_option"
config = Config('Test app')
section_name = 'my_section'
opt_name = 'my_option'
assert not config.isset(section_name, opt_name)
section = config.add_section("my_section")
section = config.add_section('my_section')
assert isinstance(section, ConfigSection)
section.add_option(StringOption, opt_name)
@ -193,11 +195,11 @@ def test_not_isset():
def test_get():
config = Config("Test app")
section_name = "my_section"
opt_name = "my_option"
opt_value = "value"
section = config.add_section("my_section")
config = Config('Test app')
section_name = 'my_section'
opt_name = 'my_option'
opt_value = 'value'
section = config.add_section('my_section')
option = section.add_option(StringOption, opt_name)
config.parse_arguments_options(argv=[option.parser_argument_name, opt_value], create=False)
@ -205,11 +207,11 @@ def test_get():
def test_get_default():
config = Config("Test app")
section_name = "my_section"
opt_name = "my_option"
opt_default_value = "value"
section = config.add_section("my_section")
config = Config('Test app')
section_name = 'my_section'
opt_name = 'my_option'
opt_default_value = 'value'
section = config.add_section('my_section')
section.add_option(StringOption, opt_name, default=opt_default_value)
config.parse_arguments_options(argv=[], create=False)
@ -217,8 +219,8 @@ def test_get_default():
def test_logging_splited_stdout_stderr(capsys):
config = Config("Test app")
config.parse_arguments_options(argv=["-C", "-v"], create=False)
config = Config('Test app')
config.parse_arguments_options(argv=['-C', '-v'], create=False)
info_msg = "[info]"
err_msg = "[error]"
logging.getLogger().info(info_msg)
@ -237,9 +239,9 @@ def test_logging_splited_stdout_stderr(capsys):
@pytest.fixture()
def config_with_file(tmpdir):
config = Config("Test app")
config_dir = tmpdir.mkdir("config")
config_file = config_dir.join("config.ini")
config = Config('Test app')
config_dir = tmpdir.mkdir('config')
config_file = config_dir.join('config.ini')
config.save(os.path.join(config_file.dirname, config_file.basename))
return config
@ -248,7 +250,6 @@ def generate_mock_input(expected_prompt, input_value):
def mock_input(self, prompt): # pylint: disable=unused-argument
assert prompt == expected_prompt
return input_value
return mock_input
@ -256,9 +257,10 @@ def generate_mock_input(expected_prompt, input_value):
def test_boolean_option_from_config(config_with_file):
section = config_with_file.add_section("test")
section = config_with_file.add_section('test')
default = True
option = section.add_option(BooleanOption, "test_bool", default=default)
option = section.add_option(
BooleanOption, 'test_bool', default=default)
config_with_file.save()
option.set(not default)
@ -271,76 +273,74 @@ def test_boolean_option_from_config(config_with_file):
def test_boolean_option_ask_value(mocker):
config = Config("Test app")
section = config.add_section("test")
name = "test_bool"
option = section.add_option(BooleanOption, name, default=True)
config = Config('Test app')
section = config.add_section('test')
name = 'test_bool'
option = section.add_option(
BooleanOption, name, default=True)
mocker.patch(
"mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "y")
'mylib.config.BooleanOption._get_user_input',
generate_mock_input(f'{name}: [Y/n] ', 'y')
)
assert option.ask_value(set_it=False) is True
mocker.patch(
"mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "Y")
'mylib.config.BooleanOption._get_user_input',
generate_mock_input(f'{name}: [Y/n] ', 'Y')
)
assert option.ask_value(set_it=False) is True
mocker.patch(
"mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "")
'mylib.config.BooleanOption._get_user_input',
generate_mock_input(f'{name}: [Y/n] ', '')
)
assert option.ask_value(set_it=False) is True
mocker.patch(
"mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "n")
'mylib.config.BooleanOption._get_user_input',
generate_mock_input(f'{name}: [Y/n] ', 'n')
)
assert option.ask_value(set_it=False) is False
mocker.patch(
"mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "N")
'mylib.config.BooleanOption._get_user_input',
generate_mock_input(f'{name}: [Y/n] ', 'N')
)
assert option.ask_value(set_it=False) is False
def test_boolean_option_to_config():
config = Config("Test app")
section = config.add_section("test")
config = Config('Test app')
section = config.add_section('test')
default = True
option = section.add_option(BooleanOption, "test_bool", default=default)
assert option.to_config(True) == "true"
assert option.to_config(False) == "false"
option = section.add_option(BooleanOption, 'test_bool', default=default)
assert option.to_config(True) == 'true'
assert option.to_config(False) == 'false'
def test_boolean_option_export_to_config(config_with_file):
section = config_with_file.add_section("test")
name = "test_bool"
comment = "Test boolean"
section = config_with_file.add_section('test')
name = 'test_bool'
comment = 'Test boolean'
default = True
option = section.add_option(BooleanOption, name, default=default, comment=comment)
option = section.add_option(
BooleanOption, name, default=default, comment=comment)
assert (
option.export_to_config()
== f"""# {comment}
assert option.export_to_config() == f"""# {comment}
# Default: {str(default).lower()}
# {name} =
"""
)
option.set(not default)
assert (
option.export_to_config()
== f"""# {comment}
assert option.export_to_config() == f"""# {comment}
# Default: {str(default).lower()}
{name} = {str(not default).lower()}
"""
)
option.set(default)
assert (
option.export_to_config()
== f"""# {comment}
assert option.export_to_config() == f"""# {comment}
# Default: {str(default).lower()}
# {name} =
"""
)

View file

@ -2,6 +2,7 @@
""" Tests on opening hours helpers """
import pytest
from MySQLdb._exceptions import Error
from mylib.mysql import MyDB
@ -10,9 +11,7 @@ from mylib.mysql import MyDB
class FakeMySQLdbCursor:
""" Fake MySQLdb cursor """
def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception):
self.expected_sql = expected_sql
self.expected_params = expected_params
self.expected_return = expected_return
@ -21,25 +20,13 @@ class FakeMySQLdbCursor:
def execute(self, sql, params=None):
if self.expected_exception:
raise Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
raise Error(f'{self}.execute({sql}, {params}): expected exception')
if self.expected_just_try and not sql.lower().startswith('select '):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode'
# pylint: disable=consider-using-f-string
assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql)
# pylint: disable=consider-using-f-string
assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params)
return self.expected_return
@property
@ -52,14 +39,16 @@ class FakeMySQLdbCursor:
def fetchall(self):
if isinstance(self.expected_return, list):
return (
list(row.values()) if isinstance(row, dict) else row for row in self.expected_return
list(row.values())
if isinstance(row, dict) else row
for row in self.expected_return
)
return self.expected_return
def __repr__(self):
return (
f"FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, "
f"{self.expected_return}, {self.expected_just_try})"
f'FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, '
f'{self.expected_return}, {self.expected_just_try})'
)
@ -74,14 +63,11 @@ class FakeMySQLdb:
just_try = False
def __init__(self, **kwargs):
allowed_kwargs = dict(
db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool
)
allowed_kwargs = dict(db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool)
for arg, value in kwargs.items():
assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"'
assert isinstance(
value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
assert isinstance(value, allowed_kwargs[arg]), \
f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})'
setattr(self, arg, value)
def close(self):
@ -89,11 +75,9 @@ class FakeMySQLdb:
def cursor(self):
return FakeMySQLdbCursor(
self.expected_sql,
self.expected_params,
self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
self.expected_sql, self.expected_params,
self.expected_return, self.expected_just_try or self.just_try,
self.expected_exception
)
def commit(self):
@ -121,19 +105,19 @@ def fake_mysqldb_connect_just_try(**kwargs):
@pytest.fixture
def test_mydb():
return MyDB("127.0.0.1", "user", "password", "dbname")
return MyDB('127.0.0.1', 'user', 'password', 'dbname')
@pytest.fixture
def fake_mydb(mocker):
mocker.patch("MySQLdb.connect", fake_mysqldb_connect)
return MyDB("127.0.0.1", "user", "password", "dbname")
mocker.patch('MySQLdb.connect', fake_mysqldb_connect)
return MyDB('127.0.0.1', 'user', 'password', 'dbname')
@pytest.fixture
def fake_just_try_mydb(mocker):
mocker.patch("MySQLdb.connect", fake_mysqldb_connect_just_try)
return MyDB("127.0.0.1", "user", "password", "dbname", just_try=True)
mocker.patch('MySQLdb.connect', fake_mysqldb_connect_just_try)
return MyDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True)
@pytest.fixture
@ -148,22 +132,13 @@ def fake_connected_just_try_mydb(fake_just_try_mydb):
return fake_just_try_mydb
def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args)
# pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs)
return expected_return
return mock_args
@ -171,22 +146,13 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql)
# pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params)
return expected_return
return mock_doSQL
@ -200,11 +166,15 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter():
assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict(
test1=1, test2=2
)
def test_combine_params_with_kargs():
assert MyDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
assert MyDB._combine_params(dict(test1=1), test2=2) == dict(
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter():
@ -214,40 +184,47 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved():
args = ("test = test", dict(test1=1))
args = ('test = test', dict(test1=1))
assert MyDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw():
assert MyDB._format_where_clauses("test = test") == ("test = test", {})
assert MyDB._format_where_clauses('test = test') == (('test = test'), {})
def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2))
where_clauses = (
'test1 = %(test1)s AND test2 = %(test2)s',
dict(test1=1, test2=2)
)
assert MyDB._format_where_clauses(where_clauses) == where_clauses
def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2)
assert MyDB._format_where_clauses(where_clauses) == (
"`test1` = %(test1)s AND `test2` = %(test2)s",
where_clauses,
'`test1` = %(test1)s AND `test2` = %(test2)s',
where_clauses
)
def test_format_where_clauses_combined_types():
where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4))
where_clauses = (
'test1 = 1',
('test2 LIKE %(test2)s', dict(test2=2)),
dict(test3=3, test4=4)
)
assert MyDB._format_where_clauses(where_clauses) == (
"test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s",
dict(test2=2, test3=3, test4=4),
'test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s',
dict(test2=2, test3=3, test4=4)
)
def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2)
assert MyDB._format_where_clauses(where_clauses, where_op="OR") == (
"`test1` = %(test1)s OR `test2` = %(test2)s",
where_clauses,
assert MyDB._format_where_clauses(where_clauses, where_op='OR') == (
'`test1` = %(test1)s OR `test2` = %(test2)s',
where_clauses
)
@ -255,8 +232,8 @@ def test_add_where_clauses():
sql = "SELECT * FROM table"
where_clauses = dict(test1=1, test2=2)
assert MyDB._add_where_clauses(sql, None, where_clauses) == (
sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s",
where_clauses,
sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s',
where_clauses
)
@ -265,102 +242,106 @@ def test_add_where_clauses_preserved_params():
where_clauses = dict(test1=1, test2=2)
params = dict(fake1=1)
assert MyDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s",
dict(**where_clauses, **params),
sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s',
dict(**where_clauses, **params)
)
def test_add_where_clauses_with_op():
sql = "SELECT * FROM table"
where_clauses = ("test1=1", "test2=2")
assert MyDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + " WHERE test1=1 OR test2=2",
{},
where_clauses = ('test1=1', 'test2=2')
assert MyDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == (
sql + ' WHERE test1=1 OR test2=2',
{}
)
def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=%(test1)s"
params = dict(test1="new_value")
where_clauses = dict(test1="where_value")
params = dict(test1='new_value')
where_clauses = dict(test1='where_value')
assert MyDB._add_where_clauses(sql, params, where_clauses) == (
sql + " WHERE `test1` = %(test1_1)s",
dict(test1="new_value", test1_1="where_value"),
sql + ' WHERE `test1` = %(test1_1)s',
dict(test1='new_value', test1_1='where_value')
)
def test_quote_table_name():
assert MyDB._quote_table_name("mytable") == "`mytable`"
assert MyDB._quote_table_name("myschema.mytable") == "`myschema`.`mytable`"
assert MyDB._quote_table_name("mytable") == '`mytable`'
assert MyDB._quote_table_name("myschema.mytable") == '`myschema`.`mytable`'
def test_insert(mocker, test_mydb):
values = dict(test1=1, test2=2)
mocker.patch(
"mylib.mysql.MyDB.doSQL",
'mylib.mysql.MyDB.doSQL',
generate_mock_doSQL(
"INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)", values
),
'INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)',
values
)
)
assert test_mydb.insert("mytable", values)
assert test_mydb.insert('mytable', values)
def test_insert_just_try(mocker, test_mydb):
mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.insert("mytable", dict(test1=1, test2=2), just_try=True)
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try)
assert test_mydb.insert('mytable', dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_mydb):
values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4)
mocker.patch(
"mylib.mysql.MyDB.doSQL",
'mylib.mysql.MyDB.doSQL',
generate_mock_doSQL(
"UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` ="
" %(test3)s AND `test4` = %(test4)s",
dict(**values, **where_clauses),
),
'UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` = %(test3)s AND `test4` = %(test4)s',
dict(**values, **where_clauses)
)
)
assert test_mydb.update("mytable", values, where_clauses)
assert test_mydb.update('mytable', values, where_clauses)
def test_update_just_try(mocker, test_mydb):
mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try)
assert test_mydb.update('mytable', dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_mydb):
where_clauses = dict(test1=1, test2=2)
mocker.patch(
"mylib.mysql.MyDB.doSQL",
'mylib.mysql.MyDB.doSQL',
generate_mock_doSQL(
"DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s", where_clauses
),
'DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s',
where_clauses
)
)
assert test_mydb.delete("mytable", where_clauses)
assert test_mydb.delete('mytable', where_clauses)
def test_delete_just_try(mocker, test_mydb):
mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.delete("mytable", None, just_try=True)
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try)
assert test_mydb.delete('mytable', None, just_try=True)
def test_truncate(mocker, test_mydb):
mocker.patch("mylib.mysql.MyDB.doSQL", generate_mock_doSQL("TRUNCATE TABLE `mytable`", None))
mocker.patch(
'mylib.mysql.MyDB.doSQL',
generate_mock_doSQL('TRUNCATE TABLE `mytable`', None)
)
assert test_mydb.truncate("mytable")
assert test_mydb.truncate('mytable')
def test_truncate_just_try(mocker, test_mydb):
mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try)
assert test_mydb.truncate("mytable", just_try=True)
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try)
assert test_mydb.truncate('mytable', just_try=True)
def test_select(mocker, test_mydb):
fields = ("field1", "field2")
fields = ('field1', 'field2')
where_clauses = dict(test3=3, test4=4)
expected_return = [
dict(field1=1, field2=2),
@ -368,28 +349,30 @@ def test_select(mocker, test_mydb):
]
order_by = "field1, DESC"
mocker.patch(
"mylib.mysql.MyDB.doSelect",
'mylib.mysql.MyDB.doSelect',
generate_mock_doSQL(
"SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` ="
" %(test4)s ORDER BY " + order_by,
where_clauses,
expected_return,
),
'SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` = %(test4)s ORDER BY ' + order_by,
where_clauses, expected_return
)
)
assert test_mydb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
assert test_mydb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return
def test_select_without_field_and_order_by(mocker, test_mydb):
mocker.patch("mylib.mysql.MyDB.doSelect", generate_mock_doSQL("SELECT * FROM `mytable`"))
mocker.patch(
'mylib.mysql.MyDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM `mytable`'
)
)
assert test_mydb.select("mytable")
assert test_mydb.select('mytable')
def test_select_just_try(mocker, test_mydb):
mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try)
assert test_mydb.select("mytable", None, None, just_try=True)
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try)
assert test_mydb.select('mytable', None, None, just_try=True)
#
# Tests on main methods
@ -406,7 +389,12 @@ def test_connect(mocker, test_mydb):
use_unicode=True,
)
mocker.patch("MySQLdb.connect", generate_mock_args(expected_kwargs=expected_kwargs))
mocker.patch(
'MySQLdb.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_mydb.connect()
@ -420,61 +408,48 @@ def test_close_connected(fake_connected_mydb):
def test_doSQL(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s"
fake_connected_mydb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s'
fake_connected_mydb._conn.expected_params = dict(test1=1)
fake_connected_mydb.doSQL(
fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params
)
fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params)
def test_doSQL_without_params(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = "DELETE FROM table"
fake_connected_mydb._conn.expected_sql = 'DELETE FROM table'
fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_mydb):
assert fake_connected_just_try_mydb.doSQL("DELETE FROM table")
assert fake_connected_just_try_mydb.doSQL('DELETE FROM table')
def test_doSQL_on_exception(fake_connected_mydb):
fake_connected_mydb._conn.expected_exception = True
assert fake_connected_mydb.doSQL("DELETE FROM table") is False
assert fake_connected_mydb.doSQL('DELETE FROM table') is False
def test_doSelect(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s'
fake_connected_mydb._conn.expected_params = dict(test1=1)
fake_connected_mydb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_mydb.doSelect(
fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params
)
== fake_connected_mydb._conn.expected_return
)
assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params) == fake_connected_mydb._conn.expected_return
def test_doSelect_without_params(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = "SELECT * FROM table"
fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table'
fake_connected_mydb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql)
== fake_connected_mydb._conn.expected_return
)
assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql) == fake_connected_mydb._conn.expected_return
def test_doSelect_on_exception(fake_connected_mydb):
fake_connected_mydb._conn.expected_exception = True
assert fake_connected_mydb.doSelect("SELECT * FROM table") is False
assert fake_connected_mydb.doSelect('SELECT * FROM table') is False
def test_doSelect_just_try(fake_connected_just_try_mydb):
fake_connected_just_try_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_just_try_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s'
fake_connected_just_try_mydb._conn.expected_params = dict(test1=1)
fake_connected_just_try_mydb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_just_try_mydb.doSelect(
assert fake_connected_just_try_mydb.doSelect(
fake_connected_just_try_mydb._conn.expected_sql,
fake_connected_just_try_mydb._conn.expected_params,
)
== fake_connected_just_try_mydb._conn.expected_return
)
fake_connected_just_try_mydb._conn.expected_params
) == fake_connected_just_try_mydb._conn.expected_return

View file

@ -2,7 +2,6 @@
""" Tests on opening hours helpers """
import datetime
import pytest
from mylib import opening_hours
@ -13,16 +12,14 @@ from mylib import opening_hours
def test_parse_exceptional_closures_one_day_without_time_period():
assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [
{"days": [datetime.date(2017, 9, 22)], "hours_periods": []}
]
assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [{'days': [datetime.date(2017, 9, 22)], 'hours_periods': []}]
def test_parse_exceptional_closures_one_day_with_time_period():
assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30"]) == [
{
"days": [datetime.date(2017, 11, 26)],
"hours_periods": [{"start": datetime.time(9, 30), "stop": datetime.time(12, 30)}],
'days': [datetime.date(2017, 11, 26)],
'hours_periods': [{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}]
}
]
@ -30,11 +27,11 @@ def test_parse_exceptional_closures_one_day_with_time_period():
def test_parse_exceptional_closures_one_day_with_multiple_time_periods():
assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30 14h-18h"]) == [
{
"days": [datetime.date(2017, 11, 26)],
"hours_periods": [
{"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
],
'days': [datetime.date(2017, 11, 26)],
'hours_periods': [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)},
]
}
]
@ -42,12 +39,8 @@ def test_parse_exceptional_closures_one_day_with_multiple_time_periods():
def test_parse_exceptional_closures_full_days_period():
assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017"]) == [
{
"days": [
datetime.date(2017, 9, 20),
datetime.date(2017, 9, 21),
datetime.date(2017, 9, 22),
],
"hours_periods": [],
'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)],
'hours_periods': []
}
]
@ -60,12 +53,8 @@ def test_parse_exceptional_closures_invalid_days_period():
def test_parse_exceptional_closures_days_period_with_time_period():
assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017 9h-12h"]) == [
{
"days": [
datetime.date(2017, 9, 20),
datetime.date(2017, 9, 21),
datetime.date(2017, 9, 22),
],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}],
'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)],
'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}]
}
]
@ -81,38 +70,31 @@ def test_parse_exceptional_closures_invalid_time_period():
def test_parse_exceptional_closures_multiple_periods():
assert opening_hours.parse_exceptional_closures(
["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"]
) == [
assert opening_hours.parse_exceptional_closures(["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"]) == [
{
"days": [
'days': [
datetime.date(2017, 9, 20),
datetime.date(2017, 11, 25),
datetime.date(2017, 11, 26),
],
"hours_periods": [
{"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
],
'hours_periods': [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)},
]
}
]
#
# Tests on parse_normal_opening_hours()
#
def test_parse_normal_opening_hours_one_day():
assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [
{"days": ["jeudi"], "hours_periods": []}
]
assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [{'days': ["jeudi"], 'hours_periods': []}]
def test_parse_normal_opening_hours_multiple_days():
assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [
{"days": ["lundi", "jeudi"], "hours_periods": []}
]
assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [{'days': ["lundi", "jeudi"], 'hours_periods': []}]
def test_parse_normal_opening_hours_invalid_day():
@ -122,17 +104,13 @@ def test_parse_normal_opening_hours_invalid_day():
def test_parse_normal_opening_hours_one_days_period():
assert opening_hours.parse_normal_opening_hours(["lundi-jeudi"]) == [
{"days": ["lundi", "mardi", "mercredi", "jeudi"], "hours_periods": []}
{'days': ["lundi", "mardi", "mercredi", "jeudi"], 'hours_periods': []}
]
def test_parse_normal_opening_hours_one_day_with_one_time_period():
assert opening_hours.parse_normal_opening_hours(["jeudi 9h-12h"]) == [
{
"days": ["jeudi"],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}],
}
]
{'days': ["jeudi"], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}]}]
def test_parse_normal_opening_hours_invalid_days_period():
@ -144,10 +122,7 @@ def test_parse_normal_opening_hours_invalid_days_period():
def test_parse_normal_opening_hours_one_time_period():
assert opening_hours.parse_normal_opening_hours(["9h-18h30"]) == [
{
"days": [],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(18, 30)}],
}
{'days': [], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(18, 30)}]}
]
@ -157,60 +132,48 @@ def test_parse_normal_opening_hours_invalid_time_period():
def test_parse_normal_opening_hours_multiple_periods():
assert opening_hours.parse_normal_opening_hours(
["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"]
) == [
assert opening_hours.parse_normal_opening_hours(["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"]) == [
{
"days": ["lundi", "mardi", "mercredi", "jeudi", "vendredi"],
"hours_periods": [
{"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
],
'days': ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi'],
'hours_periods': [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)},
]
},
{
"days": ["samedi"],
"hours_periods": [
{"start": datetime.time(9, 30), "stop": datetime.time(18, 0)},
],
'days': ['samedi'],
'hours_periods': [
{'start': datetime.time(9, 30), 'stop': datetime.time(18, 0)},
]
},
{
"days": ["dimanche"],
"hours_periods": [
{"start": datetime.time(9, 30), "stop": datetime.time(12, 0)},
],
'days': ['dimanche'],
'hours_periods': [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 0)},
]
},
]
#
# Tests on is_closed
#
exceptional_closures = [
"22/09/2017",
"20/09/2017-22/09/2017",
"20/09/2017-22/09/2017 18/09/2017",
"25/11/2017",
"26/11/2017 9h30-12h30",
]
normal_opening_hours = [
"lundi-mardi jeudi 9h30-12h30 14h-16h30",
"mercredi vendredi 9h30-12h30 14h-17h",
]
exceptional_closures = ["22/09/2017", "20/09/2017-22/09/2017", "20/09/2017-22/09/2017 18/09/2017", "25/11/2017", "26/11/2017 9h30-12h30"]
normal_opening_hours = ["lundi-mardi jeudi 9h30-12h30 14h-16h30", "mercredi vendredi 9h30-12h30 14h-17h"]
nonworking_public_holidays = [
"1janvier",
"paques",
"lundi_paques",
"1mai",
"8mai",
"jeudi_ascension",
"lundi_pentecote",
"14juillet",
"15aout",
"1novembre",
"11novembre",
"noel",
'1janvier',
'paques',
'lundi_paques',
'1mai',
'8mai',
'jeudi_ascension',
'lundi_pentecote',
'14juillet',
'15aout',
'1novembre',
'11novembre',
'noel',
]
@ -219,8 +182,12 @@ def test_is_closed_when_normaly_closed_by_hour():
normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 1, 20, 15),
) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
when=datetime.datetime(2017, 5, 1, 20, 15)
) == {
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_on_exceptional_closure_full_day():
@ -228,8 +195,12 @@ def test_is_closed_on_exceptional_closure_full_day():
normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 9, 22, 14, 15),
) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": True}
when=datetime.datetime(2017, 9, 22, 14, 15)
) == {
'closed': True,
'exceptional_closure': True,
'exceptional_closure_all_day': True
}
def test_is_closed_on_exceptional_closure_day():
@ -237,8 +208,12 @@ def test_is_closed_on_exceptional_closure_day():
normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 11, 26, 10, 30),
) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": False}
when=datetime.datetime(2017, 11, 26, 10, 30)
) == {
'closed': True,
'exceptional_closure': True,
'exceptional_closure_all_day': False
}
def test_is_closed_on_nonworking_public_holidays():
@ -246,8 +221,12 @@ def test_is_closed_on_nonworking_public_holidays():
normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 1, 1, 10, 30),
) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
when=datetime.datetime(2017, 1, 1, 10, 30)
) == {
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_when_normaly_closed_by_day():
@ -255,8 +234,12 @@ def test_is_closed_when_normaly_closed_by_day():
normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 6, 14, 15),
) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
when=datetime.datetime(2017, 5, 6, 14, 15)
) == {
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_when_normaly_opened():
@ -264,8 +247,12 @@ def test_is_closed_when_normaly_opened():
normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 2, 15, 15),
) == {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False}
when=datetime.datetime(2017, 5, 2, 15, 15)
) == {
'closed': False,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_easter_date():
@ -285,18 +272,18 @@ def test_easter_date():
def test_nonworking_french_public_days_of_the_year():
assert opening_hours.nonworking_french_public_days_of_the_year(2021) == {
"1janvier": datetime.date(2021, 1, 1),
"paques": datetime.date(2021, 4, 4),
"lundi_paques": datetime.date(2021, 4, 5),
"1mai": datetime.date(2021, 5, 1),
"8mai": datetime.date(2021, 5, 8),
"jeudi_ascension": datetime.date(2021, 5, 13),
"pentecote": datetime.date(2021, 5, 23),
"lundi_pentecote": datetime.date(2021, 5, 24),
"14juillet": datetime.date(2021, 7, 14),
"15aout": datetime.date(2021, 8, 15),
"1novembre": datetime.date(2021, 11, 1),
"11novembre": datetime.date(2021, 11, 11),
"noel": datetime.date(2021, 12, 25),
"saint_etienne": datetime.date(2021, 12, 26),
'1janvier': datetime.date(2021, 1, 1),
'paques': datetime.date(2021, 4, 4),
'lundi_paques': datetime.date(2021, 4, 5),
'1mai': datetime.date(2021, 5, 1),
'8mai': datetime.date(2021, 5, 8),
'jeudi_ascension': datetime.date(2021, 5, 13),
'pentecote': datetime.date(2021, 5, 23),
'lundi_pentecote': datetime.date(2021, 5, 24),
'14juillet': datetime.date(2021, 7, 14),
'15aout': datetime.date(2021, 8, 15),
'1novembre': datetime.date(2021, 11, 1),
'11novembre': datetime.date(2021, 11, 11),
'noel': datetime.date(2021, 12, 25),
'saint_etienne': datetime.date(2021, 12, 26)
}

View file

@ -10,9 +10,7 @@ from mylib.oracle import OracleDB
class FakeCXOracleCursor:
""" Fake cx_Oracle cursor """
def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception):
self.expected_sql = expected_sql
self.expected_params = expected_params
self.expected_return = expected_return
@ -23,25 +21,13 @@ class FakeCXOracleCursor:
def execute(self, sql, **params):
assert self.opened
if self.expected_exception:
raise cx_Oracle.Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
raise cx_Oracle.Error(f'{self}.execute({sql}, {params}): expected exception')
if self.expected_just_try and not sql.lower().startswith('select '):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode'
# pylint: disable=consider-using-f-string
assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql)
# pylint: disable=consider-using-f-string
assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params)
return self.expected_return
def fetchall(self):
@ -57,8 +43,8 @@ class FakeCXOracleCursor:
def __repr__(self):
return (
f"FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, "
f"{self.expected_return}, {self.expected_just_try})"
f'FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, '
f'{self.expected_return}, {self.expected_just_try})'
)
@ -76,9 +62,7 @@ class FakeCXOracle:
allowed_kwargs = dict(dsn=str, user=str, password=(str, None))
for arg, value in kwargs.items():
assert arg in allowed_kwargs, f"Invalid arg {arg}='{value}'"
assert isinstance(
value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
assert isinstance(value, allowed_kwargs[arg]), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
setattr(self, arg, value)
def close(self):
@ -86,11 +70,9 @@ class FakeCXOracle:
def cursor(self):
return FakeCXOracleCursor(
self.expected_sql,
self.expected_params,
self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
self.expected_sql, self.expected_params,
self.expected_return, self.expected_just_try or self.just_try,
self.expected_exception
)
def commit(self):
@ -118,19 +100,19 @@ def fake_cxoracle_connect_just_try(**kwargs):
@pytest.fixture
def test_oracledb():
return OracleDB("127.0.0.1/dbname", "user", "password")
return OracleDB('127.0.0.1/dbname', 'user', 'password')
@pytest.fixture
def fake_oracledb(mocker):
mocker.patch("cx_Oracle.connect", fake_cxoracle_connect)
return OracleDB("127.0.0.1/dbname", "user", "password")
mocker.patch('cx_Oracle.connect', fake_cxoracle_connect)
return OracleDB('127.0.0.1/dbname', 'user', 'password')
@pytest.fixture
def fake_just_try_oracledb(mocker):
mocker.patch("cx_Oracle.connect", fake_cxoracle_connect_just_try)
return OracleDB("127.0.0.1/dbname", "user", "password", just_try=True)
mocker.patch('cx_Oracle.connect', fake_cxoracle_connect_just_try)
return OracleDB('127.0.0.1/dbname', 'user', 'password', just_try=True)
@pytest.fixture
@ -145,22 +127,13 @@ def fake_connected_just_try_oracledb(fake_just_try_oracledb):
return fake_just_try_oracledb
def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args)
# pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs)
return expected_return
return mock_args
@ -168,22 +141,13 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql)
# pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params)
return expected_return
return mock_doSQL
@ -197,11 +161,15 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter():
assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict(
test1=1, test2=2
)
def test_combine_params_with_kargs():
assert OracleDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
assert OracleDB._combine_params(dict(test1=1), test2=2) == dict(
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter():
@ -211,16 +179,19 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved():
args = ("test = test", dict(test1=1))
args = ('test = test', dict(test1=1))
assert OracleDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw():
assert OracleDB._format_where_clauses("test = test") == ("test = test", {})
assert OracleDB._format_where_clauses('test = test') == (('test = test'), {})
def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ("test1 = :test1 AND test2 = :test2", dict(test1=1, test2=2))
where_clauses = (
'test1 = :test1 AND test2 = :test2',
dict(test1=1, test2=2)
)
assert OracleDB._format_where_clauses(where_clauses) == where_clauses
@ -228,23 +199,27 @@ def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2)
assert OracleDB._format_where_clauses(where_clauses) == (
'"test1" = :test1 AND "test2" = :test2',
where_clauses,
where_clauses
)
def test_format_where_clauses_combined_types():
where_clauses = ("test1 = 1", ("test2 LIKE :test2", dict(test2=2)), dict(test3=3, test4=4))
where_clauses = (
'test1 = 1',
('test2 LIKE :test2', dict(test2=2)),
dict(test3=3, test4=4)
)
assert OracleDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE :test2 AND "test3" = :test3 AND "test4" = :test4',
dict(test2=2, test3=3, test4=4),
dict(test2=2, test3=3, test4=4)
)
def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2)
assert OracleDB._format_where_clauses(where_clauses, where_op="OR") == (
assert OracleDB._format_where_clauses(where_clauses, where_op='OR') == (
'"test1" = :test1 OR "test2" = :test2',
where_clauses,
where_clauses
)
@ -253,7 +228,7 @@ def test_add_where_clauses():
where_clauses = dict(test1=1, test2=2)
assert OracleDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE "test1" = :test1 AND "test2" = :test2',
where_clauses,
where_clauses
)
@ -263,26 +238,26 @@ def test_add_where_clauses_preserved_params():
params = dict(fake1=1)
assert OracleDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE "test1" = :test1 AND "test2" = :test2',
dict(**where_clauses, **params),
dict(**where_clauses, **params)
)
def test_add_where_clauses_with_op():
sql = "SELECT * FROM table"
where_clauses = ("test1=1", "test2=2")
assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + " WHERE test1=1 OR test2=2",
{},
where_clauses = ('test1=1', 'test2=2')
assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == (
sql + ' WHERE test1=1 OR test2=2',
{}
)
def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=:test1"
params = dict(test1="new_value")
where_clauses = dict(test1="where_value")
params = dict(test1='new_value')
where_clauses = dict(test1='where_value')
assert OracleDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE "test1" = :test1_1',
dict(test1="new_value", test1_1="where_value"),
dict(test1='new_value', test1_1='where_value')
)
@ -294,72 +269,74 @@ def test_quote_table_name():
def test_insert(mocker, test_oracledb):
values = dict(test1=1, test2=2)
mocker.patch(
"mylib.oracle.OracleDB.doSQL",
'mylib.oracle.OracleDB.doSQL',
generate_mock_doSQL(
'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)', values
),
'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)',
values
)
)
assert test_oracledb.insert("mytable", values)
assert test_oracledb.insert('mytable', values)
def test_insert_just_try(mocker, test_oracledb):
mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.insert("mytable", dict(test1=1, test2=2), just_try=True)
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try)
assert test_oracledb.insert('mytable', dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_oracledb):
values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4)
mocker.patch(
"mylib.oracle.OracleDB.doSQL",
'mylib.oracle.OracleDB.doSQL',
generate_mock_doSQL(
'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND'
' "test4" = :test4',
dict(**values, **where_clauses),
),
'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND "test4" = :test4',
dict(**values, **where_clauses)
)
)
assert test_oracledb.update("mytable", values, where_clauses)
assert test_oracledb.update('mytable', values, where_clauses)
def test_update_just_try(mocker, test_oracledb):
mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try)
assert test_oracledb.update('mytable', dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_oracledb):
where_clauses = dict(test1=1, test2=2)
mocker.patch(
"mylib.oracle.OracleDB.doSQL",
'mylib.oracle.OracleDB.doSQL',
generate_mock_doSQL(
'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2', where_clauses
),
'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2',
where_clauses
)
)
assert test_oracledb.delete("mytable", where_clauses)
assert test_oracledb.delete('mytable', where_clauses)
def test_delete_just_try(mocker, test_oracledb):
mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.delete("mytable", None, just_try=True)
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try)
assert test_oracledb.delete('mytable', None, just_try=True)
def test_truncate(mocker, test_oracledb):
mocker.patch(
"mylib.oracle.OracleDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
'mylib.oracle.OracleDB.doSQL',
generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
)
assert test_oracledb.truncate("mytable")
assert test_oracledb.truncate('mytable')
def test_truncate_just_try(mocker, test_oracledb):
mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try)
assert test_oracledb.truncate("mytable", just_try=True)
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try)
assert test_oracledb.truncate('mytable', just_try=True)
def test_select(mocker, test_oracledb):
fields = ("field1", "field2")
fields = ('field1', 'field2')
where_clauses = dict(test3=3, test4=4)
expected_return = [
dict(field1=1, field2=2),
@ -367,30 +344,30 @@ def test_select(mocker, test_oracledb):
]
order_by = "field1, DESC"
mocker.patch(
"mylib.oracle.OracleDB.doSelect",
'mylib.oracle.OracleDB.doSelect',
generate_mock_doSQL(
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4'
" ORDER BY " + order_by,
where_clauses,
expected_return,
),
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4 ORDER BY ' + order_by,
where_clauses, expected_return
)
)
assert (
test_oracledb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
)
assert test_oracledb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return
def test_select_without_field_and_order_by(mocker, test_oracledb):
mocker.patch("mylib.oracle.OracleDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"'))
mocker.patch(
'mylib.oracle.OracleDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM "mytable"'
)
)
assert test_oracledb.select("mytable")
assert test_oracledb.select('mytable')
def test_select_just_try(mocker, test_oracledb):
mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try)
assert test_oracledb.select("mytable", None, None, just_try=True)
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try)
assert test_oracledb.select('mytable', None, None, just_try=True)
#
# Tests on main methods
@ -399,10 +376,17 @@ def test_select_just_try(mocker, test_oracledb):
def test_connect(mocker, test_oracledb):
expected_kwargs = dict(
dsn=test_oracledb._dsn, user=test_oracledb._user, password=test_oracledb._pwd
dsn=test_oracledb._dsn,
user=test_oracledb._user,
password=test_oracledb._pwd
)
mocker.patch("cx_Oracle.connect", generate_mock_args(expected_kwargs=expected_kwargs))
mocker.patch(
'cx_Oracle.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_oracledb.connect()
@ -416,62 +400,50 @@ def test_close_connected(fake_connected_oracledb):
def test_doSQL(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = "DELETE FROM table WHERE test1 = :test1"
fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table WHERE test1 = :test1'
fake_connected_oracledb._conn.expected_params = dict(test1=1)
fake_connected_oracledb.doSQL(
fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params
)
fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params)
def test_doSQL_without_params(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = "DELETE FROM table"
fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table'
fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_oracledb):
assert fake_connected_just_try_oracledb.doSQL("DELETE FROM table")
assert fake_connected_just_try_oracledb.doSQL('DELETE FROM table')
def test_doSQL_on_exception(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_exception = True
assert fake_connected_oracledb.doSQL("DELETE FROM table") is False
assert fake_connected_oracledb.doSQL('DELETE FROM table') is False
def test_doSelect(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1"
fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1'
fake_connected_oracledb._conn.expected_params = dict(test1=1)
fake_connected_oracledb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_oracledb.doSelect(
assert fake_connected_oracledb.doSelect(
fake_connected_oracledb._conn.expected_sql,
fake_connected_oracledb._conn.expected_params,
)
== fake_connected_oracledb._conn.expected_return
)
fake_connected_oracledb._conn.expected_params) == fake_connected_oracledb._conn.expected_return
def test_doSelect_without_params(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table"
fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table'
fake_connected_oracledb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql)
== fake_connected_oracledb._conn.expected_return
)
assert fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql) == fake_connected_oracledb._conn.expected_return
def test_doSelect_on_exception(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_exception = True
assert fake_connected_oracledb.doSelect("SELECT * FROM table") is False
assert fake_connected_oracledb.doSelect('SELECT * FROM table') is False
def test_doSelect_just_try(fake_connected_just_try_oracledb):
fake_connected_just_try_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1"
fake_connected_just_try_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1'
fake_connected_just_try_oracledb._conn.expected_params = dict(test1=1)
fake_connected_just_try_oracledb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_just_try_oracledb.doSelect(
assert fake_connected_just_try_oracledb.doSelect(
fake_connected_just_try_oracledb._conn.expected_sql,
fake_connected_just_try_oracledb._conn.expected_params,
)
== fake_connected_just_try_oracledb._conn.expected_return
)
fake_connected_just_try_oracledb._conn.expected_params
) == fake_connected_just_try_oracledb._conn.expected_return

View file

@ -10,9 +10,7 @@ from mylib.pgsql import PgDB
class FakePsycopg2Cursor:
""" Fake Psycopg2 cursor """
def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception):
self.expected_sql = expected_sql
self.expected_params = expected_params
self.expected_return = expected_return
@ -21,25 +19,13 @@ class FakePsycopg2Cursor:
def execute(self, sql, params=None):
if self.expected_exception:
raise psycopg2.Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
raise psycopg2.Error(f'{self}.execute({sql}, {params}): expected exception')
if self.expected_just_try and not sql.lower().startswith('select '):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode'
# pylint: disable=consider-using-f-string
assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql)
# pylint: disable=consider-using-f-string
assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params)
return self.expected_return
def fetchall(self):
@ -47,8 +33,8 @@ class FakePsycopg2Cursor:
def __repr__(self):
return (
f"FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, "
f"{self.expected_return}, {self.expected_just_try})"
f'FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, '
f'{self.expected_return}, {self.expected_just_try})'
)
@ -66,9 +52,8 @@ class FakePsycopg2:
allowed_kwargs = dict(dbname=str, user=str, password=(str, None), host=str)
for arg, value in kwargs.items():
assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"'
assert isinstance(
value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
assert isinstance(value, allowed_kwargs[arg]), \
f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})'
setattr(self, arg, value)
def close(self):
@ -78,16 +63,14 @@ class FakePsycopg2:
self._check_just_try()
assert len(arg) == 1 and isinstance(arg[0], str)
if self.expected_exception:
raise psycopg2.Error(f"set_client_encoding({arg[0]}): Expected exception")
raise psycopg2.Error(f'set_client_encoding({arg[0]}): Expected exception')
return self.expected_return
def cursor(self):
return FakePsycopg2Cursor(
self.expected_sql,
self.expected_params,
self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
self.expected_sql, self.expected_params,
self.expected_return, self.expected_just_try or self.just_try,
self.expected_exception
)
def commit(self):
@ -115,19 +98,19 @@ def fake_psycopg2_connect_just_try(**kwargs):
@pytest.fixture
def test_pgdb():
return PgDB("127.0.0.1", "user", "password", "dbname")
return PgDB('127.0.0.1', 'user', 'password', 'dbname')
@pytest.fixture
def fake_pgdb(mocker):
mocker.patch("psycopg2.connect", fake_psycopg2_connect)
return PgDB("127.0.0.1", "user", "password", "dbname")
mocker.patch('psycopg2.connect', fake_psycopg2_connect)
return PgDB('127.0.0.1', 'user', 'password', 'dbname')
@pytest.fixture
def fake_just_try_pgdb(mocker):
mocker.patch("psycopg2.connect", fake_psycopg2_connect_just_try)
return PgDB("127.0.0.1", "user", "password", "dbname", just_try=True)
mocker.patch('psycopg2.connect', fake_psycopg2_connect_just_try)
return PgDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True)
@pytest.fixture
@ -142,22 +125,13 @@ def fake_connected_just_try_pgdb(fake_just_try_pgdb):
return fake_just_try_pgdb
def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args)
# pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs)
return expected_return
return mock_args
@ -165,22 +139,13 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql)
# pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params)
return expected_return
return mock_doSQL
@ -194,11 +159,15 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter():
assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict(
test1=1, test2=2
)
def test_combine_params_with_kargs():
assert PgDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
assert PgDB._combine_params(dict(test1=1), test2=2) == dict(
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter():
@ -208,16 +177,19 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved():
args = ("test = test", dict(test1=1))
args = ('test = test', dict(test1=1))
assert PgDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw():
assert PgDB._format_where_clauses("test = test") == ("test = test", {})
assert PgDB._format_where_clauses('test = test') == (('test = test'), {})
def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2))
where_clauses = (
'test1 = %(test1)s AND test2 = %(test2)s',
dict(test1=1, test2=2)
)
assert PgDB._format_where_clauses(where_clauses) == where_clauses
@ -225,23 +197,27 @@ def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2)
assert PgDB._format_where_clauses(where_clauses) == (
'"test1" = %(test1)s AND "test2" = %(test2)s',
where_clauses,
where_clauses
)
def test_format_where_clauses_combined_types():
where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4))
where_clauses = (
'test1 = 1',
('test2 LIKE %(test2)s', dict(test2=2)),
dict(test3=3, test4=4)
)
assert PgDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE %(test2)s AND "test3" = %(test3)s AND "test4" = %(test4)s',
dict(test2=2, test3=3, test4=4),
dict(test2=2, test3=3, test4=4)
)
def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2)
assert PgDB._format_where_clauses(where_clauses, where_op="OR") == (
assert PgDB._format_where_clauses(where_clauses, where_op='OR') == (
'"test1" = %(test1)s OR "test2" = %(test2)s',
where_clauses,
where_clauses
)
@ -250,7 +226,7 @@ def test_add_where_clauses():
where_clauses = dict(test1=1, test2=2)
assert PgDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s',
where_clauses,
where_clauses
)
@ -260,26 +236,26 @@ def test_add_where_clauses_preserved_params():
params = dict(fake1=1)
assert PgDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s',
dict(**where_clauses, **params),
dict(**where_clauses, **params)
)
def test_add_where_clauses_with_op():
sql = "SELECT * FROM table"
where_clauses = ("test1=1", "test2=2")
assert PgDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + " WHERE test1=1 OR test2=2",
{},
where_clauses = ('test1=1', 'test2=2')
assert PgDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == (
sql + ' WHERE test1=1 OR test2=2',
{}
)
def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=%(test1)s"
params = dict(test1="new_value")
where_clauses = dict(test1="where_value")
params = dict(test1='new_value')
where_clauses = dict(test1='where_value')
assert PgDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE "test1" = %(test1_1)s',
dict(test1="new_value", test1_1="where_value"),
dict(test1='new_value', test1_1='where_value')
)
@ -291,70 +267,74 @@ def test_quote_table_name():
def test_insert(mocker, test_pgdb):
values = dict(test1=1, test2=2)
mocker.patch(
"mylib.pgsql.PgDB.doSQL",
'mylib.pgsql.PgDB.doSQL',
generate_mock_doSQL(
'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)', values
),
'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)',
values
)
)
assert test_pgdb.insert("mytable", values)
assert test_pgdb.insert('mytable', values)
def test_insert_just_try(mocker, test_pgdb):
mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.insert("mytable", dict(test1=1, test2=2), just_try=True)
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try)
assert test_pgdb.insert('mytable', dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_pgdb):
values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4)
mocker.patch(
"mylib.pgsql.PgDB.doSQL",
'mylib.pgsql.PgDB.doSQL',
generate_mock_doSQL(
'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" ='
' %(test3)s AND "test4" = %(test4)s',
dict(**values, **where_clauses),
),
'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" = %(test3)s AND "test4" = %(test4)s',
dict(**values, **where_clauses)
)
)
assert test_pgdb.update("mytable", values, where_clauses)
assert test_pgdb.update('mytable', values, where_clauses)
def test_update_just_try(mocker, test_pgdb):
mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try)
assert test_pgdb.update('mytable', dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_pgdb):
where_clauses = dict(test1=1, test2=2)
mocker.patch(
"mylib.pgsql.PgDB.doSQL",
'mylib.pgsql.PgDB.doSQL',
generate_mock_doSQL(
'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s', where_clauses
),
'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s',
where_clauses
)
)
assert test_pgdb.delete("mytable", where_clauses)
assert test_pgdb.delete('mytable', where_clauses)
def test_delete_just_try(mocker, test_pgdb):
mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.delete("mytable", None, just_try=True)
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try)
assert test_pgdb.delete('mytable', None, just_try=True)
def test_truncate(mocker, test_pgdb):
mocker.patch("mylib.pgsql.PgDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None))
mocker.patch(
'mylib.pgsql.PgDB.doSQL',
generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
)
assert test_pgdb.truncate("mytable")
assert test_pgdb.truncate('mytable')
def test_truncate_just_try(mocker, test_pgdb):
mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try)
assert test_pgdb.truncate("mytable", just_try=True)
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try)
assert test_pgdb.truncate('mytable', just_try=True)
def test_select(mocker, test_pgdb):
fields = ("field1", "field2")
fields = ('field1', 'field2')
where_clauses = dict(test3=3, test4=4)
expected_return = [
dict(field1=1, field2=2),
@ -362,28 +342,30 @@ def test_select(mocker, test_pgdb):
]
order_by = "field1, DESC"
mocker.patch(
"mylib.pgsql.PgDB.doSelect",
'mylib.pgsql.PgDB.doSelect',
generate_mock_doSQL(
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" ='
" %(test4)s ORDER BY " + order_by,
where_clauses,
expected_return,
),
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" = %(test4)s ORDER BY ' + order_by,
where_clauses, expected_return
)
)
assert test_pgdb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
assert test_pgdb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return
def test_select_without_field_and_order_by(mocker, test_pgdb):
mocker.patch("mylib.pgsql.PgDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"'))
mocker.patch(
'mylib.pgsql.PgDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM "mytable"'
)
)
assert test_pgdb.select("mytable")
assert test_pgdb.select('mytable')
def test_select_just_try(mocker, test_pgdb):
mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try)
assert test_pgdb.select("mytable", None, None, just_try=True)
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try)
assert test_pgdb.select('mytable', None, None, just_try=True)
#
# Tests on main methods
@ -392,10 +374,18 @@ def test_select_just_try(mocker, test_pgdb):
def test_connect(mocker, test_pgdb):
expected_kwargs = dict(
dbname=test_pgdb._db, user=test_pgdb._user, host=test_pgdb._host, password=test_pgdb._pwd
dbname=test_pgdb._db,
user=test_pgdb._user,
host=test_pgdb._host,
password=test_pgdb._pwd
)
mocker.patch("psycopg2.connect", generate_mock_args(expected_kwargs=expected_kwargs))
mocker.patch(
'psycopg2.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_pgdb.connect()
@ -409,74 +399,61 @@ def test_close_connected(fake_connected_pgdb):
def test_setEncoding(fake_connected_pgdb):
assert fake_connected_pgdb.setEncoding("utf8")
assert fake_connected_pgdb.setEncoding('utf8')
def test_setEncoding_not_connected(fake_pgdb):
assert fake_pgdb.setEncoding("utf8") is False
assert fake_pgdb.setEncoding('utf8') is False
def test_setEncoding_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.setEncoding("utf8") is False
assert fake_connected_pgdb.setEncoding('utf8') is False
def test_doSQL(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s"
fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s'
fake_connected_pgdb._conn.expected_params = dict(test1=1)
fake_connected_pgdb.doSQL(
fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params
)
fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params)
def test_doSQL_without_params(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = "DELETE FROM table"
fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table'
fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_pgdb):
assert fake_connected_just_try_pgdb.doSQL("DELETE FROM table")
assert fake_connected_just_try_pgdb.doSQL('DELETE FROM table')
def test_doSQL_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.doSQL("DELETE FROM table") is False
assert fake_connected_pgdb.doSQL('DELETE FROM table') is False
def test_doSelect(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s'
fake_connected_pgdb._conn.expected_params = dict(test1=1)
fake_connected_pgdb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_pgdb.doSelect(
fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params
)
== fake_connected_pgdb._conn.expected_return
)
assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params) == fake_connected_pgdb._conn.expected_return
def test_doSelect_without_params(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table"
fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table'
fake_connected_pgdb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql)
== fake_connected_pgdb._conn.expected_return
)
assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql) == fake_connected_pgdb._conn.expected_return
def test_doSelect_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.doSelect("SELECT * FROM table") is False
assert fake_connected_pgdb.doSelect('SELECT * FROM table') is False
def test_doSelect_just_try(fake_connected_just_try_pgdb):
fake_connected_just_try_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_just_try_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s'
fake_connected_just_try_pgdb._conn.expected_params = dict(test1=1)
fake_connected_just_try_pgdb._conn.expected_return = [dict(test1=1)]
assert (
fake_connected_just_try_pgdb.doSelect(
assert fake_connected_just_try_pgdb.doSelect(
fake_connected_just_try_pgdb._conn.expected_sql,
fake_connected_just_try_pgdb._conn.expected_params,
)
== fake_connected_just_try_pgdb._conn.expected_return
)
fake_connected_just_try_pgdb._conn.expected_params
) == fake_connected_just_try_pgdb._conn.expected_return

View file

@ -3,14 +3,13 @@
import datetime
import os
import pytest
from mylib.telltale import TelltaleFile
def test_create_telltale_file(tmp_path):
filename = "test"
filename = 'test'
file = TelltaleFile(filename=filename, dirpath=tmp_path)
assert file.filename == filename
assert file.dirpath == tmp_path
@ -25,15 +24,15 @@ def test_create_telltale_file(tmp_path):
def test_create_telltale_file_with_filepath_and_invalid_dirpath():
with pytest.raises(AssertionError):
TelltaleFile(filepath="/tmp/test", dirpath="/var/tmp")
TelltaleFile(filepath='/tmp/test', dirpath='/var/tmp')
def test_create_telltale_file_with_filepath_and_invalid_filename():
with pytest.raises(AssertionError):
TelltaleFile(filepath="/tmp/test", filename="other")
TelltaleFile(filepath='/tmp/test', filename='other')
def test_remove_telltale_file(tmp_path):
file = TelltaleFile(filename="test", dirpath=tmp_path)
file = TelltaleFile(filename='test', dirpath=tmp_path)
file.update()
assert file.remove()