Introduce pyupgrade, isort and black, and configure pre-commit hooks to run all testing tools before commit

Benjamin Renard 2023-01-16 12:56:12 +01:00
parent a83c3d635f
commit 62c3fadf96
34 changed files with 2356 additions and 2026 deletions

.pre-commit-config.yaml Normal file

@ -0,0 +1,39 @@
# Pre-commit hooks to run tests and ensure code is cleaned.
# See https://pre-commit.com for more information
repos:
  - repo: local
    hooks:
      - id: pytest
        name: pytest
        entry: python3 -m pytest tests
        language: system
        pass_filenames: false
        always_run: true
  - repo: local
    hooks:
      - id: pylint
        name: pylint
        entry: pylint --extension-pkg-whitelist=cx_Oracle
        language: system
        types: [python]
        require_serial: true
  - repo: https://github.com/PyCQA/flake8
    rev: 6.0.0
    hooks:
      - id: flake8
        args: ['--max-line-length=100']
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.3.1
    hooks:
      - id: pyupgrade
        args: ['--keep-percent-format', '--py37-plus']
  - repo: https://github.com/psf/black
    rev: 22.12.0
    hooks:
      - id: black
        args: ['--target-version', 'py37', '--line-length', '100']
  - repo: https://github.com/PyCQA/isort
    rev: 5.11.4
    hooks:
      - id: isort
        args: ['--profile', 'black', '--line-length', '100']
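A quick way to exercise this new configuration locally, as a sketch (it assumes pre-commit itself is already installed, for example with `pip install pre-commit`):

```bash
# Install the git hook described by .pre-commit-config.yaml
pre-commit install
# Run every configured hook once against all files, not only the staged ones
pre-commit run --all-files
```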


@ -8,5 +8,8 @@ disable=invalid-name,
too-many-nested-blocks,
too-many-instance-attributes,
too-many-lines,
line-too-long,
duplicate-code,
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=100


@ -49,6 +49,41 @@ Just run `python setup.py install`
To know how to use these libs, you can take a look at the *mylib.scripts* content or at the *tests* directory.
## Code Style
[pylint](https://pypi.org/project/pylint/) is used to check for errors and enforce a coding standard, using those parameters:
```bash
pylint --extension-pkg-whitelist=cx_Oracle
```
[flake8](https://pypi.org/project/flake8/) is also used to check for errors and enforce a coding standard, using those parameters:
```bash
flake8 --max-line-length=100
```
[black](https://pypi.org/project/black/) is used to format the code, using those parameters:
```bash
black --target-version py37 --line-length 100
```
[isort](https://pypi.org/project/isort/) is used to format the imports, using those parameters:
```bash
isort --profile black --line-length 100
```
[pyupgrade](https://pypi.org/project/pyupgrade/) is used to automatically upgrade syntax, using those parameters:
```bash
pyupgrade --keep-percent-format --py37-plus
```
**Note:** A `.pre-commit-config.yaml` file is provided so that [pre-commit](https://pre-commit.com/) can run these tools automatically before each commit. After cloning the repository, execute `pre-commit install` to install the git hook.
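For instance, a single hook can also be run on demand by its id; an illustrative example (these lines are not part of the committed README):

```bash
# Run only the black and isort hooks on the whole repository
pre-commit run black --all-files
pre-commit run isort --all-files
```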
## Copyright
Copyright (c) 2013-2021 Benjamin Renard <brenard@zionetrix.net>


@ -10,7 +10,7 @@ def increment_prefix(prefix):
return f'{prefix if prefix else " "} ' return f'{prefix if prefix else " "} '
def pretty_format_value(value, encoding='utf8', prefix=None): def pretty_format_value(value, encoding="utf8", prefix=None):
"""Returned pretty formated value to display""" """Returned pretty formated value to display"""
if isinstance(value, dict): if isinstance(value, dict):
return pretty_format_dict(value, encoding=encoding, prefix=prefix) return pretty_format_dict(value, encoding=encoding, prefix=prefix)
@ -22,10 +22,10 @@ def pretty_format_value(value, encoding='utf8', prefix=None):
return f"'{value}'" return f"'{value}'"
if value is None: if value is None:
return "None" return "None"
return f'{value} ({type(value)})' return f"{value} ({type(value)})"
def pretty_format_value_in_list(value, encoding='utf8', prefix=None): def pretty_format_value_in_list(value, encoding="utf8", prefix=None):
""" """
Returned pretty formated value to display in list Returned pretty formated value to display in list
@ -34,42 +34,31 @@ def pretty_format_value_in_list(value, encoding='utf8', prefix=None):
""" """
prefix = prefix if prefix else "" prefix = prefix if prefix else ""
value = pretty_format_value(value, encoding, prefix) value = pretty_format_value(value, encoding, prefix)
if '\n' in value: if "\n" in value:
inc_prefix = increment_prefix(prefix) inc_prefix = increment_prefix(prefix)
value = "\n" + "\n".join([ value = "\n" + "\n".join([inc_prefix + line for line in value.split("\n")])
inc_prefix + line
for line in value.split('\n')
])
return value return value
def pretty_format_dict(value, encoding='utf8', prefix=None): def pretty_format_dict(value, encoding="utf8", prefix=None):
"""Returned pretty formated dict to display""" """Returned pretty formated dict to display"""
prefix = prefix if prefix else "" prefix = prefix if prefix else ""
result = [] result = []
for key in sorted(value.keys()): for key in sorted(value.keys()):
result.append( result.append(
f'{prefix}- {key} : ' f"{prefix}- {key} : "
+ pretty_format_value_in_list( + pretty_format_value_in_list(value[key], encoding=encoding, prefix=prefix)
value[key],
encoding=encoding,
prefix=prefix
)
) )
return "\n".join(result) return "\n".join(result)
def pretty_format_list(row, encoding='utf8', prefix=None): def pretty_format_list(row, encoding="utf8", prefix=None):
"""Returned pretty formated list to display""" """Returned pretty formated list to display"""
prefix = prefix if prefix else "" prefix = prefix if prefix else ""
result = [] result = []
for idx, values in enumerate(row): for idx, values in enumerate(row):
result.append( result.append(
f'{prefix}- #{idx} : ' f"{prefix}- #{idx} : "
+ pretty_format_value_in_list( + pretty_format_value_in_list(values, encoding=encoding, prefix=prefix)
values,
encoding=encoding,
prefix=prefix
)
) )
return "\n".join(result) return "\n".join(result)

File diff suppressed because it is too large.


@ -1,10 +1,7 @@
# -*- coding: utf-8 -*-
""" Basic SQL DB client """ """ Basic SQL DB client """
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -12,6 +9,7 @@ log = logging.getLogger(__name__)
# Exceptions # Exceptions
# #
class DBException(Exception): class DBException(Exception):
"""That is the base exception class for all the other exceptions provided by this module.""" """That is the base exception class for all the other exceptions provided by this module."""
@ -29,7 +27,8 @@ class DBNotImplemented(DBException, RuntimeError):
def __init__(self, method, class_name): def __init__(self, method, class_name):
super().__init__( super().__init__(
"The method {method} is not yet implemented in class {class_name}", "The method {method} is not yet implemented in class {class_name}",
method=method, class_name=class_name method=method,
class_name=class_name,
) )
@ -39,10 +38,7 @@ class DBFailToConnect(DBException, RuntimeError):
""" """
def __init__(self, uri): def __init__(self, uri):
super().__init__( super().__init__("An error occured during database connection ({uri})", uri=uri)
"An error occured during database connection ({uri})",
uri=uri
)
class DBDuplicatedSQLParameter(DBException, KeyError): class DBDuplicatedSQLParameter(DBException, KeyError):
@ -53,8 +49,7 @@ class DBDuplicatedSQLParameter(DBException, KeyError):
def __init__(self, parameter_name): def __init__(self, parameter_name):
super().__init__( super().__init__(
"Duplicated SQL parameter '{parameter_name}'", "Duplicated SQL parameter '{parameter_name}'", parameter_name=parameter_name
parameter_name=parameter_name
) )
@ -65,10 +60,7 @@ class DBUnsupportedWHEREClauses(DBException, TypeError):
""" """
def __init__(self, where_clauses): def __init__(self, where_clauses):
super().__init__( super().__init__("Unsupported WHERE clauses: {where_clauses}", where_clauses=where_clauses)
"Unsupported WHERE clauses: {where_clauses}",
where_clauses=where_clauses
)
class DBInvalidOrderByClause(DBException, TypeError): class DBInvalidOrderByClause(DBException, TypeError):
@ -79,8 +71,9 @@ class DBInvalidOrderByClause(DBException, TypeError):
def __init__(self, order_by): def __init__(self, order_by):
super().__init__( super().__init__(
"Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values (ordering field name and direction)", "Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values"
order_by=order_by " (ordering field name and direction)",
order_by=order_by,
) )
@ -93,11 +86,11 @@ class DB:
self.just_try = just_try self.just_try = just_try
self._conn = None self._conn = None
for arg, value in kwargs.items(): for arg, value in kwargs.items():
setattr(self, f'_{arg}', value) setattr(self, f"_{arg}", value)
def connect(self, exit_on_error=True): def connect(self, exit_on_error=True):
"""Connect to DB server""" """Connect to DB server"""
raise DBNotImplemented('connect', self.__class__.__name__) raise DBNotImplemented("connect", self.__class__.__name__)
def close(self): def close(self):
"""Close connection with DB server (if opened)""" """Close connection with DB server (if opened)"""
@ -110,12 +103,11 @@ class DB:
log.debug( log.debug(
'Run SQL query "%s" %s', 'Run SQL query "%s" %s',
sql, sql,
"with params = {0}".format( # pylint: disable=consider-using-f-string "with params = {}".format( # pylint: disable=consider-using-f-string
', '.join([ ", ".join([f"{key} = {value}" for key, value in params.items()])
f'{key} = {value}' if params
for key, value in params.items() else "without params"
]) if params else "without params" ),
)
) )
@staticmethod @staticmethod
@ -123,12 +115,11 @@ class DB:
log.exception( log.exception(
'Error during SQL query "%s" %s', 'Error during SQL query "%s" %s',
sql, sql,
"with params = {0}".format( # pylint: disable=consider-using-f-string "with params = {}".format( # pylint: disable=consider-using-f-string
', '.join([ ", ".join([f"{key} = {value}" for key, value in params.items()])
f'{key} = {value}' if params
for key, value in params.items() else "without params"
]) if params else "without params" ),
)
) )
def doSQL(self, sql, params=None): def doSQL(self, sql, params=None):
@ -141,7 +132,7 @@ class DB:
:return: True on success, False otherwise :return: True on success, False otherwise
:rtype: bool :rtype: bool
""" """
raise DBNotImplemented('doSQL', self.__class__.__name__) raise DBNotImplemented("doSQL", self.__class__.__name__)
def doSelect(self, sql, params=None): def doSelect(self, sql, params=None):
""" """
@ -153,7 +144,7 @@ class DB:
:return: List of selected rows as dict on success, False otherwise :return: List of selected rows as dict on success, False otherwise
:rtype: list, bool :rtype: list, bool
""" """
raise DBNotImplemented('doSelect', self.__class__.__name__) raise DBNotImplemented("doSelect", self.__class__.__name__)
# #
# SQL helpers # SQL helpers
@ -162,10 +153,8 @@ class DB:
@staticmethod @staticmethod
def _quote_table_name(table): def _quote_table_name(table):
"""Quote table name""" """Quote table name"""
return '"{0}"'.format( # pylint: disable=consider-using-f-string return '"{}"'.format( # pylint: disable=consider-using-f-string
'"."'.join( '"."'.join(table.split("."))
table.split('.')
)
) )
@staticmethod @staticmethod
@ -176,7 +165,7 @@ class DB:
@staticmethod @staticmethod
def format_param(param): def format_param(param):
"""Format SQL query parameter for prepared query""" """Format SQL query parameter for prepared query"""
return f'%({param})s' return f"%({param})s"
@classmethod @classmethod
def _combine_params(cls, params, to_add=None, **kwargs): def _combine_params(cls, params, to_add=None, **kwargs):
@ -201,7 +190,8 @@ class DB:
- a dict of WHERE clauses with field name as key and WHERE clause value as value - a dict of WHERE clauses with field name as key and WHERE clause value as value
- a list of any of previous valid WHERE clauses - a list of any of previous valid WHERE clauses
:param params: Dict of other already set SQL query parameters (optional) :param params: Dict of other already set SQL query parameters (optional)
:param where_op: SQL operator used to combine WHERE clauses together (optional, default: AND) :param where_op: SQL operator used to combine WHERE clauses together (optional, default:
AND)
:return: A tuple of two elements: raw SQL WHERE combined clauses and parameters on success :return: A tuple of two elements: raw SQL WHERE combined clauses and parameters on success
:rtype: string, bool :rtype: string, bool
@ -209,24 +199,27 @@ class DB:
if params is None: if params is None:
params = {} params = {}
if where_op is None: if where_op is None:
where_op = 'AND' where_op = "AND"
if isinstance(where_clauses, str): if isinstance(where_clauses, str):
return (where_clauses, params) return (where_clauses, params)
if isinstance(where_clauses, tuple) and len(where_clauses) == 2 and isinstance(where_clauses[1], dict): if (
isinstance(where_clauses, tuple)
and len(where_clauses) == 2
and isinstance(where_clauses[1], dict)
):
cls._combine_params(params, where_clauses[1]) cls._combine_params(params, where_clauses[1])
return (where_clauses[0], params) return (where_clauses[0], params)
if isinstance(where_clauses, (list, tuple)): if isinstance(where_clauses, (list, tuple)):
sql_where_clauses = [] sql_where_clauses = []
for where_clause in where_clauses: for where_clause in where_clauses:
sql2, params = cls._format_where_clauses(where_clause, params=params, where_op=where_op) sql2, params = cls._format_where_clauses(
sql_where_clauses.append(sql2) where_clause, params=params, where_op=where_op
return (
f' {where_op} '.join(sql_where_clauses),
params
) )
sql_where_clauses.append(sql2)
return (f" {where_op} ".join(sql_where_clauses), params)
if isinstance(where_clauses, dict): if isinstance(where_clauses, dict):
sql_where_clauses = [] sql_where_clauses = []
@ -235,16 +228,13 @@ class DB:
if field in params: if field in params:
idx = 1 idx = 1
while param in params: while param in params:
param = f'{field}_{idx}' param = f"{field}_{idx}"
idx += 1 idx += 1
cls._combine_params(params, {param: value}) cls._combine_params(params, {param: value})
sql_where_clauses.append( sql_where_clauses.append(
f'{cls._quote_field_name(field)} = {cls.format_param(param)}' f"{cls._quote_field_name(field)} = {cls.format_param(param)}"
)
return (
f' {where_op} '.join(sql_where_clauses),
params
) )
return (f" {where_op} ".join(sql_where_clauses), params)
raise DBUnsupportedWHEREClauses(where_clauses) raise DBUnsupportedWHEREClauses(where_clauses)
@classmethod @classmethod
@ -255,29 +245,26 @@ class DB:
:param sql: The SQL query to complete :param sql: The SQL query to complete
:param params: The dict of parameters of the SQL query to complete :param params: The dict of parameters of the SQL query to complete
:param where_clauses: The WHERE clause (see _format_where_clauses()) :param where_clauses: The WHERE clause (see _format_where_clauses())
:param where_op: SQL operator used to combine WHERE clauses together (optional, default: see _format_where_clauses()) :param where_op: SQL operator used to combine WHERE clauses together (optional, default:
see _format_where_clauses())
:return: :return:
:rtype: A tuple of two elements: raw SQL WHERE combined clauses and parameters :rtype: A tuple of two elements: raw SQL WHERE combined clauses and parameters
""" """
if where_clauses: if where_clauses:
sql_where, params = cls._format_where_clauses(where_clauses, params=params, where_op=where_op) sql_where, params = cls._format_where_clauses(
where_clauses, params=params, where_op=where_op
)
sql += " WHERE " + sql_where sql += " WHERE " + sql_where
return (sql, params) return (sql, params)
def insert(self, table, values, just_try=False): def insert(self, table, values, just_try=False):
"""Run INSERT SQL query""" """Run INSERT SQL query"""
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
sql = 'INSERT INTO {0} ({1}) VALUES ({2})'.format( sql = "INSERT INTO {} ({}) VALUES ({})".format(
self._quote_table_name(table), self._quote_table_name(table),
', '.join([ ", ".join([self._quote_field_name(field) for field in values.keys()]),
self._quote_field_name(field) ", ".join([self.format_param(key) for key in values]),
for field in values.keys()
]),
", ".join([
self.format_param(key)
for key in values
])
) )
if just_try: if just_try:
@ -293,19 +280,18 @@ class DB:
def update(self, table, values, where_clauses, where_op=None, just_try=False): def update(self, table, values, where_clauses, where_op=None, just_try=False):
"""Run UPDATE SQL query""" """Run UPDATE SQL query"""
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
sql = 'UPDATE {0} SET {1}'.format( sql = "UPDATE {} SET {}".format(
self._quote_table_name(table), self._quote_table_name(table),
", ".join([ ", ".join(
f'{self._quote_field_name(key)} = {self.format_param(key)}' [f"{self._quote_field_name(key)} = {self.format_param(key)}" for key in values]
for key in values ),
])
) )
params = values params = values
try: try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error('Fail to add WHERE clauses', exc_info=True) log.error("Fail to add WHERE clauses", exc_info=True)
return False return False
if just_try: if just_try:
@ -318,15 +304,15 @@ class DB:
return False return False
return True return True
def delete(self, table, where_clauses, where_op='AND', just_try=False): def delete(self, table, where_clauses, where_op="AND", just_try=False):
"""Run DELETE SQL query""" """Run DELETE SQL query"""
sql = f'DELETE FROM {self._quote_table_name(table)}' sql = f"DELETE FROM {self._quote_table_name(table)}"
params = {} params = {}
try: try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error('Fail to add WHERE clauses', exc_info=True) log.error("Fail to add WHERE clauses", exc_info=True)
return False return False
if just_try: if just_try:
@ -341,7 +327,7 @@ class DB:
def truncate(self, table, just_try=False): def truncate(self, table, just_try=False):
"""Run TRUNCATE SQL query""" """Run TRUNCATE SQL query"""
sql = f'TRUNCATE TABLE {self._quote_table_name(table)}' sql = f"TRUNCATE TABLE {self._quote_table_name(table)}"
if just_try: if just_try:
log.debug("Just-try mode: execute TRUNCATE query: %s", sql) log.debug("Just-try mode: execute TRUNCATE query: %s", sql)
@ -353,33 +339,36 @@ class DB:
return False return False
return True return True
def select(self, table, where_clauses=None, fields=None, where_op='AND', order_by=None, just_try=False): def select(
self, table, where_clauses=None, fields=None, where_op="AND", order_by=None, just_try=False
):
"""Run SELECT SQL query""" """Run SELECT SQL query"""
sql = "SELECT " sql = "SELECT "
if fields is None: if fields is None:
sql += "*" sql += "*"
elif isinstance(fields, str): elif isinstance(fields, str):
sql += f'{self._quote_field_name(fields)}' sql += f"{self._quote_field_name(fields)}"
else: else:
sql += ', '.join([self._quote_field_name(field) for field in fields]) sql += ", ".join([self._quote_field_name(field) for field in fields])
sql += f' FROM {self._quote_table_name(table)}' sql += f" FROM {self._quote_table_name(table)}"
params = {} params = {}
try: try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error('Fail to add WHERE clauses', exc_info=True) log.error("Fail to add WHERE clauses", exc_info=True)
return False return False
if order_by: if order_by:
if isinstance(order_by, str): if isinstance(order_by, str):
sql += f' ORDER BY {order_by}' sql += f" ORDER BY {order_by}"
elif ( elif (
isinstance(order_by, (list, tuple)) and len(order_by) == 2 isinstance(order_by, (list, tuple))
and len(order_by) == 2
and isinstance(order_by[0], str) and isinstance(order_by[0], str)
and isinstance(order_by[1], str) and isinstance(order_by[1], str)
and order_by[1].upper() in ('ASC', 'UPPER') and order_by[1].upper() in ("ASC", "UPPER")
): ):
sql += f' ORDER BY "{order_by[0]}" {order_by[1].upper()}' sql += f' ORDER BY "{order_by[0]}" {order_by[1].upper()}'
else: else:


@ -1,49 +1,51 @@
# -*- coding: utf-8 -*-
""" Email client to forge and send emails """ """ Email client to forge and send emails """
import email.utils
import logging import logging
import os import os
import smtplib import smtplib
import email.utils
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
from email.encoders import encode_base64 from email.encoders import encode_base64
from email.mime.base import MIMEBase
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from mako.template import Template as MakoTemplate from mako.template import Template as MakoTemplate
from mylib.config import ConfigurableObject from mylib.config import (
from mylib.config import BooleanOption BooleanOption,
from mylib.config import IntegerOption ConfigurableObject,
from mylib.config import PasswordOption IntegerOption,
from mylib.config import StringOption PasswordOption,
StringOption,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inheritance,too-many-instance-attributes class EmailClient(
ConfigurableObject
): # pylint: disable=useless-object-inheritance,too-many-instance-attributes
""" """
Email client Email client
This class abstract all interactions with the SMTP server. This class abstract all interactions with the SMTP server.
""" """
_config_name = 'email' _config_name = "email"
_config_comment = 'Email' _config_comment = "Email"
_defaults = { _defaults = {
'smtp_host': 'localhost', "smtp_host": "localhost",
'smtp_port': 25, "smtp_port": 25,
'smtp_ssl': False, "smtp_ssl": False,
'smtp_tls': False, "smtp_tls": False,
'smtp_user': None, "smtp_user": None,
'smtp_password': None, "smtp_password": None,
'smtp_debug': False, "smtp_debug": False,
'sender_name': 'No reply', "sender_name": "No reply",
'sender_email': 'noreply@localhost', "sender_email": "noreply@localhost",
'encoding': 'utf-8', "encoding": "utf-8",
'catch_all_addr': None, "catch_all_addr": None,
'just_try': False, "just_try": False,
} }
templates = {} templates = {}
@ -61,55 +63,101 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit
if use_smtp: if use_smtp:
section.add_option( section.add_option(
StringOption, 'smtp_host', default=self._defaults['smtp_host'], StringOption,
comment='SMTP server hostname/IP address') "smtp_host",
default=self._defaults["smtp_host"],
comment="SMTP server hostname/IP address",
)
section.add_option( section.add_option(
IntegerOption, 'smtp_port', default=self._defaults['smtp_port'], IntegerOption,
comment='SMTP server port') "smtp_port",
default=self._defaults["smtp_port"],
comment="SMTP server port",
)
section.add_option( section.add_option(
BooleanOption, 'smtp_ssl', default=self._defaults['smtp_ssl'], BooleanOption,
comment='Use SSL on SMTP server connection') "smtp_ssl",
default=self._defaults["smtp_ssl"],
comment="Use SSL on SMTP server connection",
)
section.add_option( section.add_option(
BooleanOption, 'smtp_tls', default=self._defaults['smtp_tls'], BooleanOption,
comment='Use TLS on SMTP server connection') "smtp_tls",
default=self._defaults["smtp_tls"],
comment="Use TLS on SMTP server connection",
)
section.add_option( section.add_option(
StringOption, 'smtp_user', default=self._defaults['smtp_user'], StringOption,
comment='SMTP authentication username') "smtp_user",
default=self._defaults["smtp_user"],
comment="SMTP authentication username",
)
section.add_option( section.add_option(
PasswordOption, 'smtp_password', default=self._defaults['smtp_password'], PasswordOption,
"smtp_password",
default=self._defaults["smtp_password"],
comment='SMTP authentication password (set to "keyring" to use XDG keyring)', comment='SMTP authentication password (set to "keyring" to use XDG keyring)',
username_option='smtp_user', keyring_value='keyring') username_option="smtp_user",
keyring_value="keyring",
)
section.add_option( section.add_option(
BooleanOption, 'smtp_debug', default=self._defaults['smtp_debug'], BooleanOption,
comment='Enable SMTP debugging') "smtp_debug",
default=self._defaults["smtp_debug"],
comment="Enable SMTP debugging",
)
section.add_option( section.add_option(
StringOption, 'sender_name', default=self._defaults['sender_name'], StringOption,
comment='Sender name') "sender_name",
default=self._defaults["sender_name"],
comment="Sender name",
)
section.add_option( section.add_option(
StringOption, 'sender_email', default=self._defaults['sender_email'], StringOption,
comment='Sender email address') "sender_email",
default=self._defaults["sender_email"],
comment="Sender email address",
)
section.add_option( section.add_option(
StringOption, 'encoding', default=self._defaults['encoding'], StringOption, "encoding", default=self._defaults["encoding"], comment="Email encoding"
comment='Email encoding') )
section.add_option( section.add_option(
StringOption, 'catch_all_addr', default=self._defaults['catch_all_addr'], StringOption,
comment='Catch all sent emails to this specified email address') "catch_all_addr",
default=self._defaults["catch_all_addr"],
comment="Catch all sent emails to this specified email address",
)
if just_try: if just_try:
section.add_option( section.add_option(
BooleanOption, 'just_try', default=self._defaults['just_try'], BooleanOption,
comment='Just-try mode: do not really send emails') "just_try",
default=self._defaults["just_try"],
comment="Just-try mode: do not really send emails",
)
return section return section
def forge_message(self, rcpt_to, subject=None, html_body=None, text_body=None, # pylint: disable=too-many-arguments,too-many-locals def forge_message(
attachment_files=None, attachment_payloads=None, sender_name=None, self,
sender_email=None, encoding=None, template=None, **template_vars): rcpt_to,
subject=None,
html_body=None,
text_body=None, # pylint: disable=too-many-arguments,too-many-locals
attachment_files=None,
attachment_payloads=None,
sender_name=None,
sender_email=None,
encoding=None,
template=None,
**template_vars,
):
""" """
Forge a message Forge a message
:param rcpt_to: The recipient of the email. Could be a tuple(name, email) or just the email of the recipient. :param rcpt_to: The recipient of the email. Could be a tuple(name, email) or
just the email of the recipient.
:param subject: The subject of the email. :param subject: The subject of the email.
:param html_body: The HTML body of the email :param html_body: The HTML body of the email
:param text_body: The plain text body of the email :param text_body: The plain text body of the email
@ -122,64 +170,69 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit
All other parameters will be consider as template variables. All other parameters will be consider as template variables.
""" """
msg = MIMEMultipart('alternative') msg = MIMEMultipart("alternative")
msg['To'] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to msg["To"] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to
msg['From'] = email.utils.formataddr( msg["From"] = email.utils.formataddr(
( (
sender_name or self._get_option('sender_name'), sender_name or self._get_option("sender_name"),
sender_email or self._get_option('sender_email') sender_email or self._get_option("sender_email"),
) )
) )
if subject: if subject:
msg['Subject'] = subject.format(**template_vars) msg["Subject"] = subject.format(**template_vars)
msg['Date'] = email.utils.formatdate(None, True) msg["Date"] = email.utils.formatdate(None, True)
encoding = encoding if encoding else self._get_option('encoding') encoding = encoding if encoding else self._get_option("encoding")
if template: if template:
assert template in self.templates, f'Unknwon template {template}' assert template in self.templates, f"Unknwon template {template}"
# Handle subject from template # Handle subject from template
if not subject: if not subject:
assert self.templates[template].get('subject'), f'No subject defined in template {template}' assert self.templates[template].get(
msg['Subject'] = self.templates[template]['subject'].format(**template_vars) "subject"
), f"No subject defined in template {template}"
msg["Subject"] = self.templates[template]["subject"].format(**template_vars)
# Put HTML part in last one to prefered it # Put HTML part in last one to prefered it
parts = [] parts = []
if self.templates[template].get('text'): if self.templates[template].get("text"):
if isinstance(self.templates[template]['text'], MakoTemplate): if isinstance(self.templates[template]["text"], MakoTemplate):
parts.append((self.templates[template]['text'].render(**template_vars), 'plain')) parts.append(
(self.templates[template]["text"].render(**template_vars), "plain")
)
else: else:
parts.append((self.templates[template]['text'].format(**template_vars), 'plain')) parts.append(
if self.templates[template].get('html'): (self.templates[template]["text"].format(**template_vars), "plain")
if isinstance(self.templates[template]['html'], MakoTemplate): )
parts.append((self.templates[template]['html'].render(**template_vars), 'html')) if self.templates[template].get("html"):
if isinstance(self.templates[template]["html"], MakoTemplate):
parts.append((self.templates[template]["html"].render(**template_vars), "html"))
else: else:
parts.append((self.templates[template]['html'].format(**template_vars), 'html')) parts.append((self.templates[template]["html"].format(**template_vars), "html"))
for body, mime_type in parts: for body, mime_type in parts:
msg.attach(MIMEText(body.encode(encoding), mime_type, _charset=encoding)) msg.attach(MIMEText(body.encode(encoding), mime_type, _charset=encoding))
else: else:
assert subject, 'No subject provided' assert subject, "No subject provided"
if text_body: if text_body:
msg.attach(MIMEText(text_body.encode(encoding), 'plain', _charset=encoding)) msg.attach(MIMEText(text_body.encode(encoding), "plain", _charset=encoding))
if html_body: if html_body:
msg.attach(MIMEText(html_body.encode(encoding), 'html', _charset=encoding)) msg.attach(MIMEText(html_body.encode(encoding), "html", _charset=encoding))
if attachment_files: if attachment_files:
for filepath in attachment_files: for filepath in attachment_files:
with open(filepath, 'rb') as fp: with open(filepath, "rb") as fp:
part = MIMEBase('application', "octet-stream") part = MIMEBase("application", "octet-stream")
part.set_payload(fp.read()) part.set_payload(fp.read())
encode_base64(part) encode_base64(part)
part.add_header( part.add_header(
'Content-Disposition', "Content-Disposition",
f'attachment; filename="{os.path.basename(filepath)}"') f'attachment; filename="{os.path.basename(filepath)}"',
)
msg.attach(part) msg.attach(part)
if attachment_payloads: if attachment_payloads:
for filename, payload in attachment_payloads: for filename, payload in attachment_payloads:
part = MIMEBase('application', "octet-stream") part = MIMEBase("application", "octet-stream")
part.set_payload(payload) part.set_payload(payload)
encode_base64(part) encode_base64(part)
part.add_header( part.add_header("Content-Disposition", f'attachment; filename="{filename}"')
'Content-Disposition',
f'attachment; filename="{filename}"')
msg.attach(part) msg.attach(part)
return msg return msg
@ -192,200 +245,184 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit
:param msg: The message of this email (as MIMEBase or derivated classes) :param msg: The message of this email (as MIMEBase or derivated classes)
:param subject: The subject of the email (only if the message is not provided :param subject: The subject of the email (only if the message is not provided
using msg parameter) using msg parameter)
:param just_try: Enable just try mode (do not really send email, default: as defined on initialization) :param just_try: Enable just try mode (do not really send email, default: as defined on
initialization)
All other parameters will be consider as parameters to forge the message All other parameters will be consider as parameters to forge the message
(only if the message is not provided using msg parameter). (only if the message is not provided using msg parameter).
""" """
msg = msg if msg else self.forge_message(rcpt_to, subject, **forge_args) msg = msg if msg else self.forge_message(rcpt_to, subject, **forge_args)
if just_try or self._get_option('just_try'): if just_try or self._get_option("just_try"):
log.debug('Just-try mode: do not really send this email to %s (subject="%s")', rcpt_to, subject or msg.get('subject', 'No subject')) log.debug(
'Just-try mode: do not really send this email to %s (subject="%s")',
rcpt_to,
subject or msg.get("subject", "No subject"),
)
return True return True
catch_addr = self._get_option('catch_all_addr') catch_addr = self._get_option("catch_all_addr")
if catch_addr: if catch_addr:
log.debug('Catch email originaly send to %s to %s', rcpt_to, catch_addr) log.debug("Catch email originaly send to %s to %s", rcpt_to, catch_addr)
rcpt_to = catch_addr rcpt_to = catch_addr
smtp_host = self._get_option('smtp_host') smtp_host = self._get_option("smtp_host")
smtp_port = self._get_option('smtp_port') smtp_port = self._get_option("smtp_port")
try: try:
if self._get_option('smtp_ssl'): if self._get_option("smtp_ssl"):
logging.info("Establish SSL connection to server %s:%s", smtp_host, smtp_port) logging.info("Establish SSL connection to server %s:%s", smtp_host, smtp_port)
server = smtplib.SMTP_SSL(smtp_host, smtp_port) server = smtplib.SMTP_SSL(smtp_host, smtp_port)
else: else:
logging.info("Establish connection to server %s:%s", smtp_host, smtp_port) logging.info("Establish connection to server %s:%s", smtp_host, smtp_port)
server = smtplib.SMTP(smtp_host, smtp_port) server = smtplib.SMTP(smtp_host, smtp_port)
if self._get_option('smtp_tls'): if self._get_option("smtp_tls"):
logging.info('Start TLS on SMTP connection') logging.info("Start TLS on SMTP connection")
server.starttls() server.starttls()
except smtplib.SMTPException: except smtplib.SMTPException:
log.error('Error connecting to SMTP server %s:%s', smtp_host, smtp_port, exc_info=True) log.error("Error connecting to SMTP server %s:%s", smtp_host, smtp_port, exc_info=True)
return False return False
if self._get_option('smtp_debug'): if self._get_option("smtp_debug"):
server.set_debuglevel(True) server.set_debuglevel(True)
smtp_user = self._get_option('smtp_user') smtp_user = self._get_option("smtp_user")
smtp_password = self._get_option('smtp_password') smtp_password = self._get_option("smtp_password")
if smtp_user and smtp_password: if smtp_user and smtp_password:
try: try:
log.info('Try to authenticate on SMTP connection as %s', smtp_user) log.info("Try to authenticate on SMTP connection as %s", smtp_user)
server.login(smtp_user, smtp_password) server.login(smtp_user, smtp_password)
except smtplib.SMTPException: except smtplib.SMTPException:
log.error( log.error(
'Error authenticating on SMTP server %s:%s with user %s', "Error authenticating on SMTP server %s:%s with user %s",
smtp_host, smtp_port, smtp_user, exc_info=True) smtp_host,
smtp_port,
smtp_user,
exc_info=True,
)
return False return False
error = False error = False
try: try:
log.info('Sending email to %s', rcpt_to) log.info("Sending email to %s", rcpt_to)
server.sendmail( server.sendmail(
self._get_option('sender_email'), self._get_option("sender_email"),
[rcpt_to[1] if isinstance(rcpt_to, tuple) else rcpt_to], [rcpt_to[1] if isinstance(rcpt_to, tuple) else rcpt_to],
msg.as_string() msg.as_string(),
) )
except smtplib.SMTPException: except smtplib.SMTPException:
error = True error = True
log.error('Error sending email to %s', rcpt_to, exc_info=True) log.error("Error sending email to %s", rcpt_to, exc_info=True)
finally: finally:
server.quit() server.quit()
return not error return not error
if __name__ == '__main__': if __name__ == "__main__":
# Run tests # Run tests
import argparse
import datetime import datetime
import sys import sys
import argparse
# Options parser # Options parser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'-v', '--verbose', "-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode"
action="store_true",
dest="verbose",
help="Enable verbose mode"
) )
parser.add_argument( parser.add_argument(
'-d', '--debug', "-d", "--debug", action="store_true", dest="debug", help="Enable debug mode"
action="store_true",
dest="debug",
help="Enable debug mode"
) )
parser.add_argument( parser.add_argument(
'-l', '--log-file', "-l", "--log-file", action="store", type=str, dest="logfile", help="Log file path"
action="store",
type=str,
dest="logfile",
help="Log file path"
) )
parser.add_argument( parser.add_argument(
'-j', '--just-try', "-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode"
action="store_true",
dest="just_try",
help="Enable just-try mode"
) )
email_opts = parser.add_argument_group('Email options') email_opts = parser.add_argument_group("Email options")
email_opts.add_argument( email_opts.add_argument(
'-H', '--smtp-host', "-H", "--smtp-host", action="store", type=str, dest="email_smtp_host", help="SMTP host"
action="store",
type=str,
dest="email_smtp_host",
help="SMTP host"
) )
email_opts.add_argument( email_opts.add_argument(
'-P', '--smtp-port', "-P", "--smtp-port", action="store", type=int, dest="email_smtp_port", help="SMTP port"
action="store",
type=int,
dest="email_smtp_port",
help="SMTP port"
) )
email_opts.add_argument( email_opts.add_argument(
'-S', '--smtp-ssl', "-S", "--smtp-ssl", action="store_true", dest="email_smtp_ssl", help="Use SSL"
action="store_true",
dest="email_smtp_ssl",
help="Use SSL"
) )
email_opts.add_argument( email_opts.add_argument(
'-T', '--smtp-tls', "-T", "--smtp-tls", action="store_true", dest="email_smtp_tls", help="Use TLS"
action="store_true",
dest="email_smtp_tls",
help="Use TLS"
) )
email_opts.add_argument( email_opts.add_argument(
'-u', '--smtp-user', "-u", "--smtp-user", action="store", type=str, dest="email_smtp_user", help="SMTP username"
action="store",
type=str,
dest="email_smtp_user",
help="SMTP username"
) )
email_opts.add_argument( email_opts.add_argument(
'-p', '--smtp-password', "-p",
"--smtp-password",
action="store", action="store",
type=str, type=str,
dest="email_smtp_password", dest="email_smtp_password",
help="SMTP password" help="SMTP password",
) )
email_opts.add_argument( email_opts.add_argument(
'-D', '--smtp-debug', "-D",
"--smtp-debug",
action="store_true", action="store_true",
dest="email_smtp_debug", dest="email_smtp_debug",
help="Debug SMTP connection" help="Debug SMTP connection",
) )
email_opts.add_argument( email_opts.add_argument(
'-e', '--email-encoding', "-e",
"--email-encoding",
action="store", action="store",
type=str, type=str,
dest="email_encoding", dest="email_encoding",
help="SMTP encoding" help="SMTP encoding",
) )
email_opts.add_argument( email_opts.add_argument(
'-f', '--sender-name', "-f",
"--sender-name",
action="store", action="store",
type=str, type=str,
dest="email_sender_name", dest="email_sender_name",
help="Sender name" help="Sender name",
) )
email_opts.add_argument( email_opts.add_argument(
'-F', '--sender-email', "-F",
"--sender-email",
action="store", action="store",
type=str, type=str,
dest="email_sender_email", dest="email_sender_email",
help="Sender email" help="Sender email",
) )
email_opts.add_argument( email_opts.add_argument(
'-C', '--catch-all', "-C",
"--catch-all",
action="store", action="store",
type=str, type=str,
dest="email_catch_all", dest="email_catch_all",
help="Catch all sent email: specify catch recipient email address" help="Catch all sent email: specify catch recipient email address",
) )
test_opts = parser.add_argument_group('Test email options') test_opts = parser.add_argument_group("Test email options")
test_opts.add_argument( test_opts.add_argument(
'-t', '--to', "-t",
"--to",
action="store", action="store",
type=str, type=str,
dest="test_to", dest="test_to",
@ -393,7 +430,8 @@ if __name__ == '__main__':
) )
test_opts.add_argument( test_opts.add_argument(
'-m', '--mako', "-m",
"--mako",
action="store_true", action="store_true",
dest="test_mako", dest="test_mako",
help="Test mako templating", help="Test mako templating",
@ -402,11 +440,11 @@ if __name__ == '__main__':
options = parser.parse_args() options = parser.parse_args()
if not options.test_to: if not options.test_to:
parser.error('You must specify test email recipient using -t/--to parameter') parser.error("You must specify test email recipient using -t/--to parameter")
sys.exit(1) sys.exit(1)
# Initialize logs # Initialize logs
logformat = '%(asctime)s - Test EmailClient - %(levelname)s - %(message)s' logformat = "%(asctime)s - Test EmailClient - %(levelname)s - %(message)s"
if options.debug: if options.debug:
loglevel = logging.DEBUG loglevel = logging.DEBUG
elif options.verbose: elif options.verbose:
@ -421,9 +459,10 @@ if __name__ == '__main__':
if options.email_smtp_user and not options.email_smtp_password: if options.email_smtp_user and not options.email_smtp_password:
import getpass import getpass
options.email_smtp_password = getpass.getpass('Please enter SMTP password: ')
logging.info('Initialize Email client') options.email_smtp_password = getpass.getpass("Please enter SMTP password: ")
logging.info("Initialize Email client")
email_client = EmailClient( email_client = EmailClient(
smtp_host=options.email_smtp_host, smtp_host=options.email_smtp_host,
smtp_port=options.email_smtp_port, smtp_port=options.email_smtp_port,
@ -441,20 +480,24 @@ if __name__ == '__main__':
test=dict( test=dict(
subject="Test email", subject="Test email",
text=( text=(
"Just a test email sent at {sent_date}." if not options.test_mako else "Just a test email sent at {sent_date}."
MakoTemplate("Just a test email sent at ${sent_date}.") if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
), ),
html=( html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else "<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>") if not options.test_mako
) else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
) )
),
) )
),
) )
logging.info('Send a test email to %s', options.test_to) logging.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
logging.info('Test email sent') logging.info("Test email sent")
sys.exit(0) sys.exit(0)
logging.error('Fail to send test email') logging.error("Fail to send test email")
sys.exit(1) sys.exit(1)
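As a side note, the test helper in this module's `__main__` block can be exercised from a shell; a sketch (the module path `mylib/email.py` is an assumption):

```bash
# Send a test email through a local SMTP server, using the mako template variant
python3 mylib/email.py --smtp-host localhost --smtp-port 25 --to someone@example.com --mako --verbose
```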


@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
""" LDAP server connection helper """ """ LDAP server connection helper """
import copy import copy
import datetime import datetime
import logging import logging
import pytz
import dateutil.parser import dateutil.parser
import dateutil.tz import dateutil.tz
import ldap import ldap
import pytz
from ldap import modlist from ldap import modlist
from ldap.controls import SimplePagedResultsControl from ldap.controls import SimplePagedResultsControl
from ldap.controls.simple import RelaxRulesControl from ldap.controls.simple import RelaxRulesControl
@ -18,34 +16,28 @@ from ldap.dn import escape_dn_chars, explode_dn
from mylib import pretty_format_dict from mylib import pretty_format_dict
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
DEFAULT_ENCODING = 'utf-8' DEFAULT_ENCODING = "utf-8"
def decode_ldap_value(value, encoding='utf-8'): def decode_ldap_value(value, encoding="utf-8"):
"""Decoding LDAP attribute values helper""" """Decoding LDAP attribute values helper"""
if isinstance(value, bytes): if isinstance(value, bytes):
return value.decode(encoding) return value.decode(encoding)
if isinstance(value, list): if isinstance(value, list):
return [decode_ldap_value(v) for v in value] return [decode_ldap_value(v) for v in value]
if isinstance(value, dict): if isinstance(value, dict):
return dict( return {key: decode_ldap_value(values) for key, values in value.items()}
(key, decode_ldap_value(values))
for key, values in value.items()
)
return value return value
def encode_ldap_value(value, encoding='utf-8'): def encode_ldap_value(value, encoding="utf-8"):
"""Encoding LDAP attribute values helper""" """Encoding LDAP attribute values helper"""
if isinstance(value, str): if isinstance(value, str):
return value.encode(encoding) return value.encode(encoding)
if isinstance(value, list): if isinstance(value, list):
return [encode_ldap_value(v) for v in value] return [encode_ldap_value(v) for v in value]
if isinstance(value, dict): if isinstance(value, dict):
return dict( return {key: encode_ldap_value(values) for key, values in value.items()}
(key, encode_ldap_value(values))
for key, values in value.items()
)
return value return value
@ -59,9 +51,17 @@ class LdapServer:
con = 0 con = 0
def __init__(self, uri, dn=None, pwd=None, v2=None, def __init__(
raiseOnError=False, logger=False, checkCert=True, self,
disableReferral=False): uri,
dn=None,
pwd=None,
v2=None,
raiseOnError=False,
logger=False,
checkCert=True,
disableReferral=False,
):
self.uri = uri self.uri = uri
self.dn = dn self.dn = dn
self.pwd = pwd self.pwd = pwd
@ -98,26 +98,27 @@ class LdapServer:
if self.dn: if self.dn:
con.simple_bind_s(self.dn, self.pwd) con.simple_bind_s(self.dn, self.pwd)
elif self.uri.startswith('ldapi://'): elif self.uri.startswith("ldapi://"):
con.sasl_interactive_bind_s("", ldap.sasl.external()) con.sasl_interactive_bind_s("", ldap.sasl.external())
self.con = con self.con = con
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error( self._error(
f'LdapServer - Error connecting and binding to LDAP server: {e}', f"LdapServer - Error connecting and binding to LDAP server: {e}",
logging.CRITICAL) logging.CRITICAL,
)
return False return False
return True return True
@staticmethod @staticmethod
def get_scope(scope): def get_scope(scope):
"""Map scope parameter to python-ldap value""" """Map scope parameter to python-ldap value"""
if scope == 'base': if scope == "base":
return ldap.SCOPE_BASE # pylint: disable=no-member return ldap.SCOPE_BASE # pylint: disable=no-member
if scope == 'one': if scope == "one":
return ldap.SCOPE_ONELEVEL # pylint: disable=no-member return ldap.SCOPE_ONELEVEL # pylint: disable=no-member
if scope == 'sub': if scope == "sub":
return ldap.SCOPE_SUBTREE # pylint: disable=no-member return ldap.SCOPE_SUBTREE # pylint: disable=no-member
raise Exception(f'Unknown LDAP scope "{scope}"') raise Exception(f'Unknown LDAP scope "{scope}"')
@ -126,9 +127,9 @@ class LdapServer:
assert self.con or self.connect() assert self.con or self.connect()
res_id = self.con.search( res_id = self.con.search(
basedn, basedn,
self.get_scope(scope if scope else 'sub'), self.get_scope(scope if scope else "sub"),
filterstr if filterstr else '(objectClass=*)', filterstr if filterstr else "(objectClass=*)",
attrs if attrs else [] attrs if attrs else [],
) )
ret = {} ret = {}
c = 0 c = 0
@ -143,64 +144,63 @@ class LdapServer:
def get_object(self, dn, filterstr=None, attrs=None): def get_object(self, dn, filterstr=None, attrs=None):
"""Retrieve a LDAP object specified by its DN""" """Retrieve a LDAP object specified by its DN"""
result = self.search(dn, filterstr=filterstr, scope='base', attrs=attrs) result = self.search(dn, filterstr=filterstr, scope="base", attrs=attrs)
return result[dn] if dn in result else None return result[dn] if dn in result else None
def paged_search(self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None, def paged_search(
sizelimit=None): self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None, sizelimit=None
):
"""Run a paged search on LDAP server""" """Run a paged search on LDAP server"""
assert not self.v2, "Paged search is not available on LDAP version 2" assert not self.v2, "Paged search is not available on LDAP version 2"
assert self.con or self.connect() assert self.con or self.connect()
# Set parameters default values (if not defined) # Set parameters default values (if not defined)
filterstr = filterstr if filterstr else '(objectClass=*)' filterstr = filterstr if filterstr else "(objectClass=*)"
attrs = attrs if attrs else [] attrs = attrs if attrs else []
scope = scope if scope else 'sub' scope = scope if scope else "sub"
pagesize = pagesize if pagesize else 500 pagesize = pagesize if pagesize else 500
# Initialize SimplePagedResultsControl object # Initialize SimplePagedResultsControl object
page_control = SimplePagedResultsControl( page_control = SimplePagedResultsControl(
True, True, size=pagesize, cookie="" # Start without cookie
size=pagesize,
cookie='' # Start without cookie
) )
ret = {} ret = {}
pages_count = 0 pages_count = 0
self.logger.debug( self.logger.debug(
"LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d and attrs=%s", "LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d"
" and attrs=%s",
basedn, basedn,
filterstr, filterstr,
scope, scope,
pagesize, pagesize,
attrs attrs,
) )
while True: while True:
pages_count += 1 pages_count += 1
self.logger.debug( self.logger.debug(
"LdapServer - Paged search: request page %d with a maximum of %d objects (current total count: %d)", "LdapServer - Paged search: request page %d with a maximum of %d objects"
" (current total count: %d)",
pages_count, pages_count,
pagesize, pagesize,
len(ret) len(ret),
) )
try: try:
res_id = self.con.search_ext( res_id = self.con.search_ext(
basedn, basedn, self.get_scope(scope), filterstr, attrs, serverctrls=[page_control]
self.get_scope(scope),
filterstr,
attrs,
serverctrls=[page_control]
) )
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error( self._error(
f'LdapServer - Error running paged search on LDAP server: {e}', f"LdapServer - Error running paged search on LDAP server: {e}", logging.CRITICAL
logging.CRITICAL) )
return False return False
try: try:
rtype, rdata, rmsgid, rctrls = self.con.result3(res_id) # pylint: disable=unused-variable # pylint: disable=unused-variable
rtype, rdata, rmsgid, rctrls = self.con.result3(res_id)
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error( self._error(
f'LdapServer - Error pulling paged search result from LDAP server: {e}', f"LdapServer - Error pulling paged search result from LDAP server: {e}",
logging.CRITICAL) logging.CRITICAL,
)
return False return False
# Detect and catch PagedResultsControl answer from rctrls # Detect and catch PagedResultsControl answer from rctrls
@ -214,8 +214,9 @@ class LdapServer:
# If PagedResultsControl answer not detected, paged serach # If PagedResultsControl answer not detected, paged serach
if not result_page_control: if not result_page_control:
self._error( self._error(
'LdapServer - Server ignores RFC2696 control, paged search can not works', "LdapServer - Server ignores RFC2696 control, paged search can not works",
logging.CRITICAL) logging.CRITICAL,
)
return False return False
# Store results of this page # Store results of this page
@ -236,7 +237,12 @@ class LdapServer:
# Otherwise, set cookie for the next search # Otherwise, set cookie for the next search
page_control.cookie = result_page_control.cookie page_control.cookie = result_page_control.cookie
self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize) self.logger.debug(
"LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)",
len(ret),
pages_count,
pagesize,
)
return ret return ret
def add_object(self, dn, attrs, encode=False): def add_object(self, dn, attrs, encode=False):
@ -248,7 +254,7 @@ class LdapServer:
self.con.add_s(dn, ldif) self.con.add_s(dn, ldif)
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error(f'LdapServer - Error adding {dn}: {e}', logging.ERROR) self._error(f"LdapServer - Error adding {dn}: {e}", logging.ERROR)
return False return False
@ -258,7 +264,7 @@ class LdapServer:
ldif = modlist.modifyModlist( ldif = modlist.modifyModlist(
encode_ldap_value(old) if encode else old, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new, encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else [],
) )
if not ldif: if not ldif:
return True return True
@ -271,8 +277,8 @@ class LdapServer:
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error( self._error(
f'LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}', f"LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}", logging.ERROR
logging.ERROR) )
return False return False
@staticmethod @staticmethod
@ -281,7 +287,7 @@ class LdapServer:
ldif = modlist.modifyModlist( ldif = modlist.modifyModlist(
encode_ldap_value(old) if encode else old, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new, encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else [],
) )
if not ldif: if not ldif:
return False return False
@ -293,44 +299,50 @@ class LdapServer:
return modlist.modifyModlist( return modlist.modifyModlist(
encode_ldap_value(old) if encode else old, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new, encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else [],
) )
@staticmethod @staticmethod
def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False): def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False):
""" Format changes (modlist) on an object based on its old and new attributes values to display/log it """ """
Format changes (modlist) on an object based on its old and new attributes values to
display/log it
"""
msg = [] msg = []
prefix = prefix if prefix else '' prefix = prefix if prefix else ""
for (op, attr, val) in modlist.modifyModlist( for op, attr, val in modlist.modifyModlist(
encode_ldap_value(old) if encode else old, encode_ldap_value(old) if encode else old,
encode_ldap_value(new) if encode else new, encode_ldap_value(new) if encode else new,
ignore_attr_types=ignore_attrs if ignore_attrs else [] ignore_attr_types=ignore_attrs if ignore_attrs else [],
): ):
if op == ldap.MOD_ADD: # pylint: disable=no-member if op == ldap.MOD_ADD: # pylint: disable=no-member
op = 'ADD' op = "ADD"
elif op == ldap.MOD_DELETE: # pylint: disable=no-member elif op == ldap.MOD_DELETE: # pylint: disable=no-member
op = 'DELETE' op = "DELETE"
elif op == ldap.MOD_REPLACE: # pylint: disable=no-member elif op == ldap.MOD_REPLACE: # pylint: disable=no-member
op = 'REPLACE' op = "REPLACE"
else: else:
op = f'UNKNOWN (={op})' op = f"UNKNOWN (={op})"
if val is None and op == 'DELETE': if val is None and op == "DELETE":
msg.append(f'{prefix} - {op} {attr}') msg.append(f"{prefix} - {op} {attr}")
else: else:
msg.append(f'{prefix} - {op} {attr}: {val}') msg.append(f"{prefix} - {op} {attr}: {val}")
return '\n'.join(msg) return "\n".join(msg)
def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True): def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True):
"""Rename an object in LDAP directory""" """Rename an object in LDAP directory"""
# If new_rdn is a complete DN, split new RDN and new superior DN # If new_rdn is a complete DN, split new RDN and new superior DN
if len(explode_dn(new_rdn)) > 1: if len(explode_dn(new_rdn)) > 1:
self.logger.debug( self.logger.debug(
"LdapServer - Rename with a full new DN detected (%s): split new RDN and new superior DN", "LdapServer - Rename with a full new DN detected (%s): split new RDN and new"
new_rdn " superior DN",
new_rdn,
) )
assert new_sup is None, "You can't provide a complete DN as new_rdn and also provide new_sup parameter" assert (
new_sup is None
), "You can't provide a complete DN as new_rdn and also provide new_sup parameter"
new_dn_parts = explode_dn(new_rdn) new_dn_parts = explode_dn(new_rdn)
new_sup = ','.join(new_dn_parts[1:]) new_sup = ",".join(new_dn_parts[1:])
new_rdn = new_dn_parts[0] new_rdn = new_dn_parts[0]
assert self.con or self.connect() assert self.con or self.connect()
try: try:
@ -339,16 +351,16 @@ class LdapServer:
dn, dn,
new_rdn, new_rdn,
"same" if new_sup is None else new_sup, "same" if new_sup is None else new_sup,
delete_old delete_old,
) )
self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old) self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old)
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error( self._error(
f'LdapServer - Error renaming {dn} in {new_rdn} ' f"LdapServer - Error renaming {dn} in {new_rdn} "
f'(new superior: {"same" if new_sup is None else new_sup}, ' f'(new superior: {"same" if new_sup is None else new_sup}, '
f'delete old: {delete_old}): {e}', f"delete old: {delete_old}): {e}",
logging.ERROR logging.ERROR,
) )
return False return False
@ -361,8 +373,7 @@ class LdapServer:
self.con.delete_s(dn) self.con.delete_s(dn)
return True return True
except ldap.LDAPError as e: # pylint: disable=no-member except ldap.LDAPError as e: # pylint: disable=no-member
self._error( self._error(f"LdapServer - Error deleting {dn}: {e}", logging.ERROR)
f'LdapServer - Error deleting {dn}: {e}', logging.ERROR)
return False return False
@ -416,11 +427,13 @@ class LdapClient:
# Cache objects # Cache objects
_cached_objects = None _cached_objects = None
def __init__(self, options=None, options_prefix=None, config=None, config_section=None, initialize=False): def __init__(
self, options=None, options_prefix=None, config=None, config_section=None, initialize=False
):
self._options = options if options else {} self._options = options if options else {}
self._options_prefix = options_prefix if options_prefix else 'ldap_' self._options_prefix = options_prefix if options_prefix else "ldap_"
self._config = config if config else None self._config = config if config else None
self._config_section = config_section if config_section else 'ldap' self._config_section = config_section if config_section else "ldap"
self._cached_objects = {} self._cached_objects = {}
if initialize: if initialize:
self.initialize() self.initialize()
@ -433,7 +446,7 @@ class LdapClient:
if self._config and self._config.defined(self._config_section, option): if self._config and self._config.defined(self._config_section, option):
return self._config.get(self._config_section, option) return self._config.get(self._config_section, option)
assert not required, f'Options {option} not defined' assert not required, f"Options {option} not defined"
return default return default
@ -441,41 +454,45 @@ class LdapClient:
def _just_try(self): def _just_try(self):
"""Check if just-try mode is enabled""" """Check if just-try mode is enabled"""
return self._get_option( return self._get_option(
'just_try', default=( "just_try", default=(self._config.get_option("just_try") if self._config else False)
self._config.get_option('just_try') if self._config
else False
)
) )
def configure(self, comment=None, **kwargs): def configure(self, comment=None, **kwargs):
"""Configure options on registered mylib.Config object""" """Configure options on registered mylib.Config object"""
assert self._config, "mylib.Config object not registered. Must be passed to __init__ as config keyword argument." assert self._config, (
"mylib.Config object not registered. Must be passed to __init__ as config keyword"
" argument."
)
# Load configuration option types only here to avoid global # Load configuration option types only here to avoid global
# dependency of ldap module with config one. # dependency of ldap module with config one.
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from mylib.config import BooleanOption, StringOption, PasswordOption from mylib.config import BooleanOption, PasswordOption, StringOption
section = self._config.add_section( section = self._config.add_section(
self._config_section, self._config_section,
comment=comment if comment else 'LDAP connection', comment=comment if comment else "LDAP connection",
loaded_callback=self.initialize, **kwargs) loaded_callback=self.initialize,
**kwargs,
)
section.add_option( section.add_option(
StringOption, 'uri', default='ldap://localhost', StringOption, "uri", default="ldap://localhost", comment="LDAP server URI"
comment='LDAP server URI') )
section.add_option(StringOption, "binddn", comment="LDAP Bind DN")
section.add_option( section.add_option(
StringOption, 'binddn', comment='LDAP Bind DN') PasswordOption,
section.add_option( "bindpwd",
PasswordOption, 'bindpwd',
comment='LDAP Bind password (set to "keyring" to use XDG keyring)', comment='LDAP Bind password (set to "keyring" to use XDG keyring)',
username_option='binddn', keyring_value='keyring') username_option="binddn",
keyring_value="keyring",
)
section.add_option( section.add_option(
BooleanOption, 'checkcert', default=True, BooleanOption, "checkcert", default=True, comment="Check LDAP certificate"
comment='Check LDAP certificate') )
section.add_option( section.add_option(
BooleanOption, 'disablereferral', default=False, BooleanOption, "disablereferral", default=False, comment="Disable referral following"
comment='Disable referral following') )
return section return section
@ -483,14 +500,16 @@ class LdapClient:
"""Initialize LDAP connection""" """Initialize LDAP connection"""
if loaded_config: if loaded_config:
self.config = loaded_config self.config = loaded_config
uri = self._get_option('uri', required=True) uri = self._get_option("uri", required=True)
binddn = self._get_option('binddn') binddn = self._get_option("binddn")
log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else 'annonymous') log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else "annonymous")
self._conn = LdapServer( self._conn = LdapServer(
uri, dn=binddn, pwd=self._get_option('bindpwd'), uri,
checkCert=self._get_option('checkcert'), dn=binddn,
disableReferral=self._get_option('disablereferral'), pwd=self._get_option("bindpwd"),
raiseOnError=True checkCert=self._get_option("checkcert"),
disableReferral=self._get_option("disablereferral"),
raiseOnError=True,
) )
# Reset cache # Reset cache
self._cached_objects = {} self._cached_objects = {}
@ -503,8 +522,8 @@ class LdapClient:
if isinstance(value, str): if isinstance(value, str):
return value return value
return value.decode( return value.decode(
self._get_option('encoding', default=DEFAULT_ENCODING), self._get_option("encoding", default=DEFAULT_ENCODING),
self._get_option('encoding_error_policy', default='ignore') self._get_option("encoding_error_policy", default="ignore"),
) )
def encode(self, value): def encode(self, value):
@ -513,7 +532,7 @@ class LdapClient:
return [self.encode(v) for v in value] return [self.encode(v) for v in value]
if isinstance(value, bytes): if isinstance(value, bytes):
return value return value
return value.encode(self._get_option('encoding', default=DEFAULT_ENCODING)) return value.encode(self._get_option("encoding", default=DEFAULT_ENCODING))
def _get_obj(self, dn, attrs): def _get_obj(self, dn, attrs):
""" """
@ -548,8 +567,17 @@ class LdapClient:
return vals if all_values else vals[0] return vals if all_values else vals[0]
return default if default or not all_values else [] return default if default or not all_values else []
def get_objects(self, name, filterstr, basedn, attrs, key_attr=None, warn=True, def get_objects(
paged_search=False, pagesize=None): self,
name,
filterstr,
basedn,
attrs,
key_attr=None,
warn=True,
paged_search=False,
pagesize=None,
):
""" """
Retrieve objects from LDAP Retrieve objects from LDAP
@ -568,25 +596,28 @@ class LdapClient:
(optional, default: see LdapServer.paged_search) (optional, default: see LdapServer.paged_search)
""" """
if name in self._cached_objects: if name in self._cached_objects:
log.debug('Retrieved %s objects from cache', name) log.debug("Retrieved %s objects from cache", name)
else: else:
assert self._conn or self.initialize() assert self._conn or self.initialize()
log.debug('Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn) log.debug(
'Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn
)
if paged_search: if paged_search:
ldap_data = self._conn.paged_search( ldap_data = self._conn.paged_search(
basedn=basedn, filterstr=filterstr, attrs=attrs, basedn=basedn, filterstr=filterstr, attrs=attrs, pagesize=pagesize
pagesize=pagesize
) )
else: else:
ldap_data = self._conn.search( ldap_data = self._conn.search(
basedn=basedn, filterstr=filterstr, attrs=attrs, basedn=basedn,
filterstr=filterstr,
attrs=attrs,
) )
if not ldap_data: if not ldap_data:
if warn: if warn:
log.warning('No %s found in LDAP', name) log.warning("No %s found in LDAP", name)
else: else:
log.debug('No %s found in LDAP', name) log.debug("No %s found in LDAP", name)
return {} return {}
objects = {} objects = {}
@ -596,12 +627,12 @@ class LdapClient:
continue continue
objects[obj_dn] = self._get_obj(obj_dn, obj_attrs) objects[obj_dn] = self._get_obj(obj_dn, obj_attrs)
self._cached_objects[name] = objects self._cached_objects[name] = objects
if not key_attr or key_attr == 'dn': if not key_attr or key_attr == "dn":
return self._cached_objects[name] return self._cached_objects[name]
return dict( return {
(self.get_attr(self._cached_objects[name][dn], key_attr), self._cached_objects[name][dn]) self.get_attr(self._cached_objects[name][dn], key_attr): self._cached_objects[name][dn]
for dn in self._cached_objects[name] for dn in self._cached_objects[name]
) }
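A usage sketch for `get_objects()`; the client setup is omitted here, and the filter, base DN and attribute names are placeholders rather than values from this repository:

```python
# Assumes `client` is an initialized mylib.ldap.LdapClient (setup omitted).
users = client.get_objects(
    "users",                           # cache name
    "(objectClass=inetOrgPerson)",     # LDAP filter (placeholder)
    "ou=people,dc=example,dc=org",     # base DN (placeholder)
    ["uid", "mail", "displayName"],    # attributes to retrieve
    key_attr="uid",                    # index results by uid instead of DN
    paged_search=True,
)
```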
def get_object(self, type_name, object_name, filterstr, basedn, attrs, warn=True): def get_object(self, type_name, object_name, filterstr, basedn, attrs, warn=True):
""" """
@ -620,11 +651,14 @@ class LdapClient:
(optional, default: True) (optional, default: True)
""" """
assert self._conn or self.initialize() assert self._conn or self.initialize()
log.debug('Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")', type_name, object_name, filterstr, basedn) log.debug(
ldap_data = self._conn.search( 'Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")',
basedn=basedn, filterstr=filterstr, type_name,
attrs=attrs object_name,
filterstr,
basedn,
) )
ldap_data = self._conn.search(basedn=basedn, filterstr=filterstr, attrs=attrs)
if not ldap_data: if not ldap_data:
if warn: if warn:
@ -635,7 +669,8 @@ class LdapClient:
if len(ldap_data) > 1: if len(ldap_data) > 1:
raise LdapClientException( raise LdapClientException(
f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}') f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}'
)
dn = next(iter(ldap_data)) dn = next(iter(ldap_data))
return self._get_obj(dn, ldap_data[dn]) return self._get_obj(dn, ldap_data[dn])
@ -659,9 +694,9 @@ class LdapClient:
populate_cache_method() populate_cache_method()
if type_name not in self._cached_objects: if type_name not in self._cached_objects:
if warn: if warn:
log.warning('No %s found in LDAP', type_name) log.warning("No %s found in LDAP", type_name)
else: else:
log.debug('No %s found in LDAP', type_name) log.debug("No %s found in LDAP", type_name)
return None return None
if dn not in self._cached_objects[type_name]: if dn not in self._cached_objects[type_name]:
if warn: if warn:
@ -686,7 +721,9 @@ class LdapClient:
return value in cls.get_attr(obj, attr, all_values=True) return value in cls.get_attr(obj, attr, all_values=True)
return value.lower() in [v.lower() for v in cls.get_attr(obj, attr, all_values=True)] return value.lower() in [v.lower() for v in cls.get_attr(obj, attr, all_values=True)]
def get_object_by_attr(self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True): def get_object_by_attr(
self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True
):
""" """
Retrieve an LDAP object specified by one of its attribute Retrieve an LDAP object specified by one of its attribute
@ -708,15 +745,15 @@ class LdapClient:
populate_cache_method() populate_cache_method()
if type_name not in self._cached_objects: if type_name not in self._cached_objects:
if warn: if warn:
log.warning('No %s found in LDAP', type_name) log.warning("No %s found in LDAP", type_name)
else: else:
log.debug('No %s found in LDAP', type_name) log.debug("No %s found in LDAP", type_name)
return None return None
matched = dict( matched = {
(dn, obj) dn: obj
for dn, obj in self._cached_objects[type_name].items() for dn, obj in self._cached_objects[type_name].items()
if self.object_attr_mached(obj, attr, value, case_sensitive=case_sensitive) if self.object_attr_mached(obj, attr, value, case_sensitive=case_sensitive)
) }
if not matched: if not matched:
if warn: if warn:
log.warning('No %s found with %s="%s"', type_name, attr, value) log.warning('No %s found with %s="%s"', type_name, attr, value)
@ -726,7 +763,8 @@ class LdapClient:
if len(matched) > 1: if len(matched) > 1:
raise LdapClientException( raise LdapClientException(
f'More than one {type_name} with {attr}="{value}" found: ' f'More than one {type_name} with {attr}="{value}" found: '
f'{" / ".join(matched.keys())}') f'{" / ".join(matched.keys())}'
)
dn = next(iter(matched)) dn = next(iter(matched))
return matched[dn] return matched[dn]
@ -742,7 +780,7 @@ class LdapClient:
old = {} old = {}
new = {} new = {}
protected_attrs = [a.lower() for a in protected_attrs or []] protected_attrs = [a.lower() for a in protected_attrs or []]
protected_attrs.append('dn') protected_attrs.append("dn")
# New/updated attributes # New/updated attributes
for attr in attrs: for attr in attrs:
if protected_attrs and attr.lower() in protected_attrs: if protected_attrs and attr.lower() in protected_attrs:
@ -755,7 +793,11 @@ class LdapClient:
# Deleted attributes # Deleted attributes
for attr in ldap_obj: for attr in ldap_obj:
if (not protected_attrs or attr.lower() not in protected_attrs) and ldap_obj[attr] and attr not in attrs: if (
(not protected_attrs or attr.lower() not in protected_attrs)
and ldap_obj[attr]
and attr not in attrs
):
old[attr] = self.encode(ldap_obj[attr]) old[attr] = self.encode(ldap_obj[attr])
if old == new: if old == new:
return None return None
@ -771,8 +813,7 @@ class LdapClient:
""" """
assert self._conn or self.initialize() assert self._conn or self.initialize()
return self._conn.format_changes( return self._conn.format_changes(
changes[0], changes[1], changes[0], changes[1], ignore_attrs=protected_attrs, prefix=prefix
ignore_attrs=protected_attrs, prefix=prefix
) )
def update_need(self, changes, protected_attrs=None): def update_need(self, changes, protected_attrs=None):
@ -784,10 +825,7 @@ class LdapClient:
if changes is None: if changes is None:
return False return False
assert self._conn or self.initialize() assert self._conn or self.initialize()
return self._conn.update_need( return self._conn.update_need(changes[0], changes[1], ignore_attrs=protected_attrs)
changes[0], changes[1],
ignore_attrs=protected_attrs
)
def add_object(self, dn, attrs): def add_object(self, dn, attrs):
""" """
@ -796,21 +834,19 @@ class LdapClient:
:param dn: The LDAP object DN :param dn: The LDAP object DN
:param attrs: The LDAP object attributes (as dict) :param attrs: The LDAP object attributes (as dict)
""" """
attrs = dict( attrs = {attr: self.encode(values) for attr, values in attrs.items() if attr != "dn"}
(attr, self.encode(values))
for attr, values in attrs.items()
if attr != 'dn'
)
try: try:
if self._just_try: if self._just_try:
log.debug('Just-try mode : do not really add object in LDAP') log.debug("Just-try mode : do not really add object in LDAP")
return True return True
assert self._conn or self.initialize() assert self._conn or self.initialize()
return self._conn.add_object(dn, attrs) return self._conn.add_object(dn, attrs)
except LdapServerException: except LdapServerException:
log.error( log.error(
"An error occurred adding object %s in LDAP:\n%s\n", "An error occurred adding object %s in LDAP:\n%s\n",
dn, pretty_format_dict(attrs), exc_info=True dn,
pretty_format_dict(attrs),
exc_info=True,
) )
return False return False
@ -823,35 +859,42 @@ class LdapClient:
:param protected_attrs: An optional list of protected attributes :param protected_attrs: An optional list of protected attributes
:param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected) :param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected)
""" """
assert isinstance(changes, (list, tuple)) and len(changes) == 2 and isinstance(changes[0], dict) and isinstance(changes[1], dict), f'changes parameter must be a result of get_changes() method ({type(changes)} given)' assert (
isinstance(changes, (list, tuple))
and len(changes) == 2
and isinstance(changes[0], dict)
and isinstance(changes[1], dict)
), f"changes parameter must be a result of get_changes() method ({type(changes)} given)"
# In case of RDN change, we need to modify passed changes, copy it to make it unchanged in # In case of RDN change, we need to modify passed changes, copy it to make it unchanged in
# this case # this case
_changes = copy.deepcopy(changes) _changes = copy.deepcopy(changes)
if not rdn_attr: if not rdn_attr:
rdn_attr = ldap_obj['dn'].split('=')[0] rdn_attr = ldap_obj["dn"].split("=")[0]
log.debug('Auto-detected RDN attribute from DN: %s => %s', ldap_obj['dn'], rdn_attr) log.debug("Auto-detected RDN attribute from DN: %s => %s", ldap_obj["dn"], rdn_attr)
old_rdn_values = self.get_attr(_changes[0], rdn_attr, all_values=True) old_rdn_values = self.get_attr(_changes[0], rdn_attr, all_values=True)
new_rdn_values = self.get_attr(_changes[1], rdn_attr, all_values=True) new_rdn_values = self.get_attr(_changes[1], rdn_attr, all_values=True)
if old_rdn_values or new_rdn_values: if old_rdn_values or new_rdn_values:
if not new_rdn_values: if not new_rdn_values:
log.error( log.error(
"%s : Attribute %s can't be deleted because it's used as RDN.", "%s : Attribute %s can't be deleted because it's used as RDN.",
ldap_obj['dn'], rdn_attr ldap_obj["dn"],
rdn_attr,
) )
return False return False
log.debug( log.debug(
'%s: Changes detected on %s RDN attribute: must rename object before updating it', "%s: Changes detected on %s RDN attribute: must rename object before updating it",
ldap_obj['dn'], rdn_attr ldap_obj["dn"],
rdn_attr,
) )
# Compute new object DN # Compute new object DN
dn_parts = explode_dn(self.decode(ldap_obj['dn'])) dn_parts = explode_dn(self.decode(ldap_obj["dn"]))
basedn = ','.join(dn_parts[1:]) basedn = ",".join(dn_parts[1:])
new_rdn = f'{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}' new_rdn = f"{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}"
new_dn = f'{new_rdn},{basedn}' new_dn = f"{new_rdn},{basedn}"
# Rename object # Rename object
log.debug('%s: Rename to %s', ldap_obj['dn'], new_dn) log.debug("%s: Rename to %s", ldap_obj["dn"], new_dn)
if not self.move_object(ldap_obj, new_dn): if not self.move_object(ldap_obj, new_dn):
return False return False
@ -865,30 +908,29 @@ class LdapClient:
# Check that there are other changes # Check that there are other changes
if not _changes[0] and not _changes[1]: if not _changes[0] and not _changes[1]:
log.debug('%s: No other change after renaming', new_dn) log.debug("%s: No other change after renaming", new_dn)
return True return True
# Otherwise, update object DN # Otherwise, update object DN
ldap_obj['dn'] = new_dn ldap_obj["dn"] = new_dn
else: else:
log.debug('%s: No change detected on RDN attribute %s', ldap_obj['dn'], rdn_attr) log.debug("%s: No change detected on RDN attribute %s", ldap_obj["dn"], rdn_attr)
try: try:
if self._just_try: if self._just_try:
log.debug('Just-try mode : do not really update object in LDAP') log.debug("Just-try mode : do not really update object in LDAP")
return True return True
assert self._conn or self.initialize() assert self._conn or self.initialize()
return self._conn.update_object( return self._conn.update_object(
ldap_obj['dn'], ldap_obj["dn"], _changes[0], _changes[1], ignore_attrs=protected_attrs
_changes[0],
_changes[1],
ignore_attrs=protected_attrs
) )
except LdapServerException: except LdapServerException:
log.error( log.error(
"An error occurred updating object %s in LDAP:\n%s\n -> \n%s\n\n", "An error occurred updating object %s in LDAP:\n%s\n -> \n%s\n\n",
ldap_obj['dn'], pretty_format_dict(_changes[0]), pretty_format_dict(_changes[1]), ldap_obj["dn"],
exc_info=True pretty_format_dict(_changes[0]),
pretty_format_dict(_changes[1]),
exc_info=True,
) )
return False return False
@ -901,14 +943,16 @@ class LdapClient:
""" """
try: try:
if self._just_try: if self._just_try:
log.debug('Just-try mode : do not really move object in LDAP') log.debug("Just-try mode : do not really move object in LDAP")
return True return True
assert self._conn or self.initialize() assert self._conn or self.initialize()
return self._conn.rename_object(ldap_obj['dn'], new_dn_or_rdn) return self._conn.rename_object(ldap_obj["dn"], new_dn_or_rdn)
except LdapServerException: except LdapServerException:
log.error( log.error(
"An error occurred moving object %s in LDAP (destination: %s)", "An error occurred moving object %s in LDAP (destination: %s)",
ldap_obj['dn'], new_dn_or_rdn, exc_info=True ldap_obj["dn"],
new_dn_or_rdn,
exc_info=True,
) )
return False return False
@ -920,15 +964,12 @@ class LdapClient:
""" """
try: try:
if self._just_try: if self._just_try:
log.debug('Just-try mode : do not really drop object in LDAP') log.debug("Just-try mode : do not really drop object in LDAP")
return True return True
assert self._conn or self.initialize() assert self._conn or self.initialize()
return self._conn.drop_object(ldap_obj['dn']) return self._conn.drop_object(ldap_obj["dn"])
except LdapServerException: except LdapServerException:
log.error( log.error("An error occurred removing object %s in LDAP", ldap_obj["dn"], exc_info=True)
"An error occurred removing object %s in LDAP",
ldap_obj['dn'], exc_info=True
)
return False return False
@ -943,20 +984,29 @@ def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
:param value: The LDAP date string to convert :param value: The LDAP date string to convert
:param to_timezone: If specified, the return datetime will be converted to this :param to_timezone: If specified, the return datetime will be converted to this
specific timezone (optional, default : timezone of the LDAP date string) specific timezone (optional, default : timezone of the LDAP date
string)
:param default_timezone: The timezone used if LDAP date string does not specify :param default_timezone: The timezone used if LDAP date string does not specify
the timezone (optional, default : server local timezone) the timezone (optional, default : server local timezone)
:param naive: Use naive datetime : return naive datetime object (without timezone conversion from LDAP) :param naive: Use naive datetime : return naive datetime object (without timezone
conversion from LDAP)
""" """
assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})' assert to_timezone is None or isinstance(
assert default_timezone is None or isinstance(default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(default_timezone)})' to_timezone, (datetime.tzinfo, str)
), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"
assert default_timezone is None or isinstance(
default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
), (
"default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
f" datetime.tzinfo object (not {type(default_timezone)})"
)
date = dateutil.parser.parse(value, dayfirst=False) date = dateutil.parser.parse(value, dayfirst=False)
if not date.tzinfo: if not date.tzinfo:
if naive: if naive:
return date return date
if not default_timezone: if not default_timezone:
default_timezone = pytz.utc default_timezone = pytz.utc
elif default_timezone == 'local': elif default_timezone == "local":
default_timezone = dateutil.tz.tzlocal() default_timezone = dateutil.tz.tzlocal()
elif isinstance(default_timezone, str): elif isinstance(default_timezone, str):
default_timezone = pytz.timezone(default_timezone) default_timezone = pytz.timezone(default_timezone)
@ -969,7 +1019,7 @@ def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
elif naive: elif naive:
return date.replace(tzinfo=None) return date.replace(tzinfo=None)
if to_timezone: if to_timezone:
if to_timezone == 'local': if to_timezone == "local":
to_timezone = dateutil.tz.tzlocal() to_timezone = dateutil.tz.tzlocal()
elif isinstance(to_timezone, str): elif isinstance(to_timezone, str):
to_timezone = pytz.timezone(to_timezone) to_timezone = pytz.timezone(to_timezone)
@ -983,7 +1033,8 @@ def parse_date(value, to_timezone=None, default_timezone=None, naive=True):
:param value: The LDAP date string to convert :param value: The LDAP date string to convert
:param to_timezone: If specified, the return datetime will be converted to this :param to_timezone: If specified, the return datetime will be converted to this
specific timezone (optional, default : timezone of the LDAP date string) specific timezone (optional, default : timezone of the LDAP date
string)
:param default_timezone: The timezone used if LDAP date string does not specify :param default_timezone: The timezone used if LDAP date string does not specify
the timezone (optional, default : server local timezone) the timezone (optional, default : server local timezone)
:param naive: Use naive datetime : do not handle timezone conversion from LDAP :param naive: Use naive datetime : do not handle timezone conversion from LDAP
@ -999,13 +1050,23 @@ def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
:param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo) :param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo)
(optional, default : server local timezone) (optional, default : server local timezone)
:param to_timezone: The timezone used in LDAP (optional, default : UTC) :param to_timezone: The timezone used in LDAP (optional, default : UTC)
:param naive: Use naive datetime : datetime store as UTC in LDAP (without conversion) :param naive: Use naive datetime : datetime store as UTC in LDAP (without
conversion)
""" """
assert isinstance(value, datetime.datetime), f'First parameter must be an datetime.datetime object (not {type(value)})' assert isinstance(
assert from_timezone is None or isinstance(from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(from_timezone)})' value, datetime.datetime
assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})' ), f"First parameter must be an datetime.datetime object (not {type(value)})"
assert from_timezone is None or isinstance(
from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)
), (
"from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a"
f" datetime.tzinfo object (not {type(from_timezone)})"
)
assert to_timezone is None or isinstance(
to_timezone, (datetime.tzinfo, str)
), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})"
if not value.tzinfo and not naive: if not value.tzinfo and not naive:
if not from_timezone or from_timezone == 'local': if not from_timezone or from_timezone == "local":
from_timezone = dateutil.tz.tzlocal() from_timezone = dateutil.tz.tzlocal()
elif isinstance(from_timezone, str): elif isinstance(from_timezone, str):
from_timezone = pytz.timezone(from_timezone) from_timezone = pytz.timezone(from_timezone)
@ -1021,14 +1082,14 @@ def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
from_value = copy.deepcopy(value) from_value = copy.deepcopy(value)
if not to_timezone: if not to_timezone:
to_timezone = pytz.utc to_timezone = pytz.utc
elif to_timezone == 'local': elif to_timezone == "local":
to_timezone = dateutil.tz.tzlocal() to_timezone = dateutil.tz.tzlocal()
elif isinstance(to_timezone, str): elif isinstance(to_timezone, str):
to_timezone = pytz.timezone(to_timezone) to_timezone = pytz.timezone(to_timezone)
to_value = from_value.astimezone(to_timezone) if not naive else from_value to_value = from_value.astimezone(to_timezone) if not naive else from_value
datestring = to_value.strftime('%Y%m%d%H%M%S%z') datestring = to_value.strftime("%Y%m%d%H%M%S%z")
if datestring.endswith('+0000'): if datestring.endswith("+0000"):
datestring = datestring.replace('+0000', 'Z') datestring = datestring.replace("+0000", "Z")
return datestring return datestring
@ -1040,8 +1101,16 @@ def format_date(value, from_timezone=None, to_timezone=None, naive=True):
:param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo) :param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo)
(optional, default : server local timezone) (optional, default : server local timezone)
:param to_timezone: The timezone used in LDAP (optional, default : UTC) :param to_timezone: The timezone used in LDAP (optional, default : UTC)
:param naive: Use naive datetime : do not handle timezone conversion before formating :param naive: Use naive datetime : do not handle timezone conversion before
and return datetime as UTC (because LDAP required a timezone) formating and return datetime as UTC (because LDAP required a
timezone)
""" """
assert isinstance(value, datetime.date), f'First parameter must be an datetime.date object (not {type(value)})' assert isinstance(
return format_datetime(datetime.datetime.combine(value, datetime.datetime.min.time()), from_timezone, to_timezone, naive) value, datetime.date
), f"First parameter must be an datetime.date object (not {type(value)})"
return format_datetime(
datetime.datetime.combine(value, datetime.datetime.min.time()),
from_timezone,
to_timezone,
naive,
)
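To illustrate the date helpers that close this module, a small sketch (the module path `mylib.ldap` and the sample timestamps are assumptions):

```python
import datetime

from mylib.ldap import format_datetime, parse_datetime  # module path assumed

# LDAP generalized-time string -> timezone-aware datetime, converted to local time
when = parse_datetime("20230116115612Z", to_timezone="local")

# Naive local datetime -> LDAP string (stored as UTC, hence the trailing Z)
print(format_datetime(datetime.datetime(2023, 1, 16, 12, 56, 12), from_timezone="local"))
# e.g. "20230116115612Z" when the local timezone is UTC+1
```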
@ -48,7 +48,6 @@ Return format :
import logging import logging
import re import re
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" MySQL client """ """ MySQL client """
import logging import logging
@ -8,9 +6,7 @@ import sys
import MySQLdb import MySQLdb
from MySQLdb._exceptions import Error from MySQLdb._exceptions import Error
from mylib.db import DB from mylib.db import DB, DBFailToConnect
from mylib.db import DBFailToConnect
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -28,7 +24,7 @@ class MyDB(DB):
self._user = user self._user = user
self._pwd = pwd self._pwd = pwd
self._db = db self._db = db
self._charset = charset if charset else 'utf8' self._charset = charset if charset else "utf8"
super().__init__(**kwargs) super().__init__(**kwargs)
def connect(self, exit_on_error=True): def connect(self, exit_on_error=True):
@ -36,17 +32,25 @@ class MyDB(DB):
if self._conn is None: if self._conn is None:
try: try:
self._conn = MySQLdb.connect( self._conn = MySQLdb.connect(
host=self._host, user=self._user, passwd=self._pwd, host=self._host,
db=self._db, charset=self._charset, use_unicode=True) user=self._user,
passwd=self._pwd,
db=self._db,
charset=self._charset,
use_unicode=True,
)
except Error as err: except Error as err:
log.fatal( log.fatal(
'An error occurred during MySQL database connection (%s@%s:%s).', "An error occurred during MySQL database connection (%s@%s:%s).",
self._user, self._host, self._db, exc_info=1 self._user,
self._host,
self._db,
exc_info=1,
) )
if exit_on_error: if exit_on_error:
sys.exit(1) sys.exit(1)
else: else:
raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err
return True return True
def doSQL(self, sql, params=None): def doSQL(self, sql, params=None):
@ -88,10 +92,7 @@ class MyDB(DB):
cursor = self._conn.cursor() cursor = self._conn.cursor()
cursor.execute(sql, params) cursor.execute(sql, params)
return [ return [
dict( {field[0]: row[idx] for idx, field in enumerate(cursor.description)}
(field[0], row[idx])
for idx, field in enumerate(cursor.description)
)
for row in cursor.fetchall() for row in cursor.fetchall()
] ]
except Error: except Error:
@ -101,13 +102,11 @@ class MyDB(DB):
@staticmethod @staticmethod
def _quote_table_name(table): def _quote_table_name(table):
"""Quote table name""" """Quote table name"""
return '`{0}`'.format( # pylint: disable=consider-using-f-string return "`{}`".format( # pylint: disable=consider-using-f-string
'`.`'.join( "`.`".join(table.split("."))
table.split('.')
)
) )
@staticmethod @staticmethod
def _quote_field_name(field): def _quote_field_name(field):
"""Quote table name""" """Quote table name"""
return f'`{field}`' return f"`{field}`"
@ -1,18 +1,16 @@
# -*- coding: utf-8 -*-
""" Opening hours helpers """ """ Opening hours helpers """
import datetime import datetime
import logging
import re import re
import time import time
import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
week_days = ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi', 'samedi', 'dimanche'] week_days = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"]
date_format = '%d/%m/%Y' date_format = "%d/%m/%Y"
date_pattern = re.compile('^([0-9]{2})/([0-9]{2})/([0-9]{4})$') date_pattern = re.compile("^([0-9]{2})/([0-9]{2})/([0-9]{4})$")
time_pattern = re.compile('^([0-9]{1,2})h([0-9]{2})?$') time_pattern = re.compile("^([0-9]{1,2})h([0-9]{2})?$")
def easter_date(year): def easter_date(year):
@ -41,20 +39,20 @@ def nonworking_french_public_days_of_the_year(year=None):
year = datetime.date.today().year year = datetime.date.today().year
dp = easter_date(year) dp = easter_date(year)
return { return {
'1janvier': datetime.date(year, 1, 1), "1janvier": datetime.date(year, 1, 1),
'paques': dp, "paques": dp,
'lundi_paques': (dp + datetime.timedelta(1)), "lundi_paques": (dp + datetime.timedelta(1)),
'1mai': datetime.date(year, 5, 1), "1mai": datetime.date(year, 5, 1),
'8mai': datetime.date(year, 5, 8), "8mai": datetime.date(year, 5, 8),
'jeudi_ascension': (dp + datetime.timedelta(39)), "jeudi_ascension": (dp + datetime.timedelta(39)),
'pentecote': (dp + datetime.timedelta(49)), "pentecote": (dp + datetime.timedelta(49)),
'lundi_pentecote': (dp + datetime.timedelta(50)), "lundi_pentecote": (dp + datetime.timedelta(50)),
'14juillet': datetime.date(year, 7, 14), "14juillet": datetime.date(year, 7, 14),
'15aout': datetime.date(year, 8, 15), "15aout": datetime.date(year, 8, 15),
'1novembre': datetime.date(year, 11, 1), "1novembre": datetime.date(year, 11, 1),
'11novembre': datetime.date(year, 11, 11), "11novembre": datetime.date(year, 11, 11),
'noel': datetime.date(year, 12, 25), "noel": datetime.date(year, 12, 25),
'saint_etienne': datetime.date(year, 12, 26), "saint_etienne": datetime.date(year, 12, 26),
} }
@ -68,7 +66,7 @@ def parse_exceptional_closures(values):
for word in words: for word in words:
if not word: if not word:
continue continue
parts = word.split('-') parts = word.split("-")
if len(parts) == 1: if len(parts) == 1:
# ex: 31/02/2017 # ex: 31/02/2017
ptime = time.strptime(word, date_format) ptime = time.strptime(word, date_format)
@ -82,7 +80,7 @@ def parse_exceptional_closures(values):
pstart = time.strptime(parts[0], date_format) pstart = time.strptime(parts[0], date_format)
pstop = time.strptime(parts[1], date_format) pstop = time.strptime(parts[1], date_format)
if pstop <= pstart: if pstop <= pstart:
raise ValueError(f'Day {parts[1]} <= {parts[0]}') raise ValueError(f"Day {parts[1]} <= {parts[0]}")
date = datetime.date(pstart.tm_year, pstart.tm_mon, pstart.tm_mday) date = datetime.date(pstart.tm_year, pstart.tm_mon, pstart.tm_mday)
stop_date = datetime.date(pstop.tm_year, pstop.tm_mon, pstop.tm_mday) stop_date = datetime.date(pstop.tm_year, pstop.tm_mon, pstop.tm_mday)
@ -99,13 +97,13 @@ def parse_exceptional_closures(values):
hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0)) hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0))
hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0)) hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0))
if hstop <= hstart: if hstop <= hstart:
raise ValueError(f'Time {parts[1]} <= {parts[0]}') raise ValueError(f"Time {parts[1]} <= {parts[0]}")
hours_periods.append({'start': hstart, 'stop': hstop}) hours_periods.append({"start": hstart, "stop": hstop})
else: else:
raise ValueError(f'Invalid number of part in this word: "{word}"') raise ValueError(f'Invalid number of part in this word: "{word}"')
if not days: if not days:
raise ValueError(f'No days found in value "{value}"') raise ValueError(f'No days found in value "{value}"')
exceptional_closures.append({'days': days, 'hours_periods': hours_periods}) exceptional_closures.append({"days": days, "hours_periods": hours_periods})
return exceptional_closures return exceptional_closures
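For reference, this parser turns space-separated dates, date ranges and hour ranges into date/time structures. A hedged example (sample values invented, module path assumed):

```python
from mylib.opening_hours import parse_exceptional_closures  # module path assumed

closures = parse_exceptional_closures(["20/01/2023 14h-18h", "06/02/2023-10/02/2023"])
# Roughly:
# [{'days': [date(2023, 1, 20)],
#   'hours_periods': [{'start': time(14, 0), 'stop': time(18, 0)}]},
#  {'days': [date(2023, 2, 6), ..., date(2023, 2, 10)], 'hours_periods': []}]
```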
@ -119,7 +117,7 @@ def parse_normal_opening_hours(values):
for word in words: for word in words:
if not word: if not word:
continue continue
parts = word.split('-') parts = word.split("-")
if len(parts) == 1: if len(parts) == 1:
# ex: jeudi # ex: jeudi
if word not in week_days: if word not in week_days:
@ -150,20 +148,23 @@ def parse_normal_opening_hours(values):
hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0)) hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0))
hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0)) hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0))
if hstop <= hstart: if hstop <= hstart:
raise ValueError(f'Time {parts[1]} <= {parts[0]}') raise ValueError(f"Time {parts[1]} <= {parts[0]}")
hours_periods.append({'start': hstart, 'stop': hstop}) hours_periods.append({"start": hstart, "stop": hstop})
else: else:
raise ValueError(f'Invalid number of part in this word: "{word}"') raise ValueError(f'Invalid number of part in this word: "{word}"')
if not days and not hours_periods: if not days and not hours_periods:
raise ValueError(f'No days or hours period found in this value: "{value}"') raise ValueError(f'No days or hours period found in this value: "{value}"')
normal_opening_hours.append({'days': days, 'hours_periods': hours_periods}) normal_opening_hours.append({"days": days, "hours_periods": hours_periods})
return normal_opening_hours return normal_opening_hours
def is_closed( def is_closed(
normal_opening_hours_values=None, exceptional_closures_values=None, normal_opening_hours_values=None,
nonworking_public_holidays_values=None, exceptional_closure_on_nonworking_public_days=False, exceptional_closures_values=None,
when=None, on_error='raise' nonworking_public_holidays_values=None,
exceptional_closure_on_nonworking_public_days=False,
when=None,
on_error="raise",
): ):
"""Check if closed""" """Check if closed"""
if not when: if not when:
@ -172,18 +173,26 @@ def is_closed(
when_time = when.time() when_time = when.time()
when_weekday = week_days[when.timetuple().tm_wday] when_weekday = week_days[when.timetuple().tm_wday]
on_error_result = None on_error_result = None
if on_error == 'closed': if on_error == "closed":
on_error_result = { on_error_result = {
'closed': True, 'exceptional_closure': False, "closed": True,
'exceptional_closure_all_day': False} "exceptional_closure": False,
elif on_error == 'opened': "exceptional_closure_all_day": False,
}
elif on_error == "opened":
on_error_result = { on_error_result = {
'closed': False, 'exceptional_closure': False, "closed": False,
'exceptional_closure_all_day': False} "exceptional_closure": False,
"exceptional_closure_all_day": False,
}
log.debug( log.debug(
"When = %s => date = %s / time = %s / week day = %s", "When = %s => date = %s / time = %s / week day = %s",
when, when_date, when_time, when_weekday) when,
when_date,
when_time,
when_weekday,
)
if nonworking_public_holidays_values: if nonworking_public_holidays_values:
log.debug("Nonworking public holidays: %s", nonworking_public_holidays_values) log.debug("Nonworking public holidays: %s", nonworking_public_holidays_values)
nonworking_days = nonworking_french_public_days_of_the_year() nonworking_days = nonworking_french_public_days_of_the_year()
@ -191,65 +200,69 @@ def is_closed(
if day in nonworking_days and when_date == nonworking_days[day]: if day in nonworking_days and when_date == nonworking_days[day]:
log.debug("Non working day: %s", day) log.debug("Non working day: %s", day)
return { return {
'closed': True, "closed": True,
'exceptional_closure': exceptional_closure_on_nonworking_public_days, "exceptional_closure": exceptional_closure_on_nonworking_public_days,
'exceptional_closure_all_day': exceptional_closure_on_nonworking_public_days "exceptional_closure_all_day": exceptional_closure_on_nonworking_public_days,
} }
if exceptional_closures_values: if exceptional_closures_values:
try: try:
exceptional_closures = parse_exceptional_closures(exceptional_closures_values) exceptional_closures = parse_exceptional_closures(exceptional_closures_values)
log.debug('Exceptional closures: %s', exceptional_closures) log.debug("Exceptional closures: %s", exceptional_closures)
except ValueError as e: except ValueError as e:
log.error("Fail to parse exceptional closures, consider as closed", exc_info=True) log.error("Fail to parse exceptional closures, consider as closed", exc_info=True)
if on_error_result is None: if on_error_result is None:
raise e from e raise e from e
return on_error_result return on_error_result
for cl in exceptional_closures: for cl in exceptional_closures:
if when_date not in cl['days']: if when_date not in cl["days"]:
log.debug("when_date (%s) no in days (%s)", when_date, cl['days']) log.debug("when_date (%s) no in days (%s)", when_date, cl["days"])
continue continue
if not cl['hours_periods']: if not cl["hours_periods"]:
# All day exceptional closure # All day exceptional closure
return { return {
'closed': True, 'exceptional_closure': True, "closed": True,
'exceptional_closure_all_day': True} "exceptional_closure": True,
for hp in cl['hours_periods']: "exceptional_closure_all_day": True,
if hp['start'] <= when_time <= hp['stop']: }
for hp in cl["hours_periods"]:
if hp["start"] <= when_time <= hp["stop"]:
return { return {
'closed': True, 'exceptional_closure': True, "closed": True,
'exceptional_closure_all_day': False} "exceptional_closure": True,
"exceptional_closure_all_day": False,
}
if normal_opening_hours_values: if normal_opening_hours_values:
try: try:
normal_opening_hours = parse_normal_opening_hours(normal_opening_hours_values) normal_opening_hours = parse_normal_opening_hours(normal_opening_hours_values)
log.debug('Normal opening hours: %s', normal_opening_hours) log.debug("Normal opening hours: %s", normal_opening_hours)
except ValueError as e: # pylint: disable=broad-except except ValueError as e: # pylint: disable=broad-except
log.error("Fail to parse normal opening hours, consider as closed", exc_info=True) log.error("Fail to parse normal opening hours, consider as closed", exc_info=True)
if on_error_result is None: if on_error_result is None:
raise e from e raise e from e
return on_error_result return on_error_result
for oh in normal_opening_hours: for oh in normal_opening_hours:
if oh['days'] and when_weekday not in oh['days']: if oh["days"] and when_weekday not in oh["days"]:
log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh['days']) log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh["days"])
continue continue
if not oh['hours_periods']: if not oh["hours_periods"]:
# All day opened # All day opened
return { return {
'closed': False, 'exceptional_closure': False, "closed": False,
'exceptional_closure_all_day': False} "exceptional_closure": False,
for hp in oh['hours_periods']: "exceptional_closure_all_day": False,
if hp['start'] <= when_time <= hp['stop']: }
for hp in oh["hours_periods"]:
if hp["start"] <= when_time <= hp["stop"]:
return { return {
'closed': False, 'exceptional_closure': False, "closed": False,
'exceptional_closure_all_day': False} "exceptional_closure": False,
"exceptional_closure_all_day": False,
}
log.debug("Not in normal opening hours => closed") log.debug("Not in normal opening hours => closed")
return { return {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
# Not a nonworking day, not during exceptional closure and no normal opening # Not a nonworking day, not during exceptional closure and no normal opening
# hours defined => Opened # hours defined => Opened
return { return {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': False, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
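Putting the pieces together, a sketch of calling `is_closed()`; the module path and all sample values below are assumptions:

```python
import datetime

from mylib.opening_hours import is_closed  # module path assumed

status = is_closed(
    normal_opening_hours_values=["lundi-vendredi 9h-12h30 14h-18h"],
    exceptional_closures_values=["20/01/2023 14h-18h"],
    nonworking_public_holidays_values=["1janvier", "1mai", "noel"],
    when=datetime.datetime(2023, 1, 20, 15, 0),
)
# Friday 2023-01-20 15:00 falls inside the exceptional closure, so roughly:
# {'closed': True, 'exceptional_closure': True, 'exceptional_closure_all_day': False}
```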
@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" Oracle client """ """ Oracle client """
import logging import logging
@ -7,8 +5,7 @@ import sys
import cx_Oracle import cx_Oracle
from mylib.db import DB from mylib.db import DB, DBFailToConnect
from mylib.db import DBFailToConnect
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -29,22 +26,20 @@ class OracleDB(DB):
def connect(self, exit_on_error=True): def connect(self, exit_on_error=True):
"""Connect to Oracle server""" """Connect to Oracle server"""
if self._conn is None: if self._conn is None:
log.info('Connect on Oracle server with DSN %s as %s', self._dsn, self._user) log.info("Connect on Oracle server with DSN %s as %s", self._dsn, self._user)
try: try:
self._conn = cx_Oracle.connect( self._conn = cx_Oracle.connect(user=self._user, password=self._pwd, dsn=self._dsn)
user=self._user,
password=self._pwd,
dsn=self._dsn
)
except cx_Oracle.Error as err: except cx_Oracle.Error as err:
log.fatal( log.fatal(
'An error occurred during Oracle database connection (%s@%s).', "An error occurred during Oracle database connection (%s@%s).",
self._user, self._dsn, exc_info=1 self._user,
self._dsn,
exc_info=1,
) )
if exit_on_error: if exit_on_error:
sys.exit(1) sys.exit(1)
else: else:
raise DBFailToConnect(f'{self._user}@{self._dsn}') from err raise DBFailToConnect(f"{self._user}@{self._dsn}") from err
return True return True
def doSQL(self, sql, params=None): def doSQL(self, sql, params=None):
@ -108,4 +103,4 @@ class OracleDB(DB):
@staticmethod @staticmethod
def format_param(param): def format_param(param):
"""Format SQL query parameter for prepared query""" """Format SQL query parameter for prepared query"""
return f':{param}' return f":{param}"
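A usage sketch for the Oracle client; the DSN and credentials are placeholders and the constructor argument order is only inferred from the attributes used in `connect()`:

```python
from mylib.oracle import OracleDB  # module path assumed

db = OracleDB("dbsrv:1521/APP", "app", "secret")  # dsn, user, pwd (assumed order)
db.connect()

# format_param() builds cx_Oracle named-bind placeholders, e.g. ":user_id"
sql = f"DELETE FROM sessions WHERE user_id = {OracleDB.format_param('user_id')}"
db.doSQL(sql, {"user_id": 42})
```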
@ -1,10 +1,8 @@
# coding: utf8
""" Progress bar """ """ Progress bar """
import logging import logging
import progressbar
import progressbar
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -25,15 +23,15 @@ class Pbar: # pylint: disable=useless-object-inheritance
self.__count = 0 self.__count = 0
self.__pbar = progressbar.ProgressBar( self.__pbar = progressbar.ProgressBar(
widgets=[ widgets=[
name + ': ', name + ": ",
progressbar.Percentage(), progressbar.Percentage(),
' ', " ",
progressbar.Bar(), progressbar.Bar(),
' ', " ",
progressbar.SimpleProgress(), progressbar.SimpleProgress(),
progressbar.ETA() progressbar.ETA(),
], ],
maxval=maxval maxval=maxval,
).start() ).start()
else: else:
log.info(name) log.info(name)
@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" PostgreSQL client """ """ PostgreSQL client """
import datetime import datetime
@ -21,8 +19,8 @@ class PgDB(DB):
_pwd = None _pwd = None
_db = None _db = None
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
datetime_format = '%Y-%m-%d %H:%M:%S' datetime_format = "%Y-%m-%d %H:%M:%S"
def __init__(self, host, user, pwd, db, **kwargs): def __init__(self, host, user, pwd, db, **kwargs):
self._host = host self._host = host
@ -36,23 +34,26 @@ class PgDB(DB):
if self._conn is None: if self._conn is None:
try: try:
log.info( log.info(
'Connect on PostgreSQL server %s as %s on database %s', "Connect on PostgreSQL server %s as %s on database %s",
self._host, self._user, self._db) self._host,
self._user,
self._db,
)
self._conn = psycopg2.connect( self._conn = psycopg2.connect(
dbname=self._db, dbname=self._db, user=self._user, host=self._host, password=self._pwd
user=self._user,
host=self._host,
password=self._pwd
) )
except psycopg2.Error as err: except psycopg2.Error as err:
log.fatal( log.fatal(
'An error occurred during Postgresql database connection (%s@%s, database=%s).', "An error occurred during Postgresql database connection (%s@%s, database=%s).",
self._user, self._host, self._db, exc_info=1 self._user,
self._host,
self._db,
exc_info=1,
) )
if exit_on_error: if exit_on_error:
sys.exit(1) sys.exit(1)
else: else:
raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err
return True return True
def close(self): def close(self):
@ -70,7 +71,8 @@ class PgDB(DB):
except psycopg2.Error: except psycopg2.Error:
log.error( log.error(
'An error occurred setting Postgresql database connection encoding to "%s"', 'An error occurred setting Postgresql database connection encoding to "%s"',
enc, exc_info=1 enc,
exc_info=1,
) )
return False return False
@ -124,10 +126,7 @@ class PgDB(DB):
@staticmethod @staticmethod
def _map_row_fields_by_index(fields, row): def _map_row_fields_by_index(fields, row):
return dict( return {field: row[idx] for idx, field in enumerate(fields)}
(field, row[idx])
for idx, field in enumerate(fields)
)
# #
# Deprecated helpers # Deprecated helpers
@ -137,7 +136,7 @@ class PgDB(DB):
def _quote_value(cls, value): def _quote_value(cls, value):
"""Quote a value for SQL query""" """Quote a value for SQL query"""
if value is None: if value is None:
return 'NULL' return "NULL"
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
return str(value) return str(value)
@ -148,7 +147,7 @@ class PgDB(DB):
value = cls._format_date(value) value = cls._format_date(value)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
return "'{0}'".format(value.replace("'", "''")) return "'{}'".format(value.replace("'", "''"))
@classmethod @classmethod
def _format_datetime(cls, value): def _format_datetime(cls, value):
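A short connection sketch for the PostgreSQL client; the host, credentials and the module path `mylib.pgsql` are placeholders:

```python
from mylib.db import DBFailToConnect
from mylib.pgsql import PgDB  # module path assumed

db = PgDB("db.example.org", "app", "secret", "appdb")
try:
    # With exit_on_error=False a failed connection raises DBFailToConnect
    # instead of calling sys.exit(1)
    db.connect(exit_on_error=False)
except DBFailToConnect:
    raise SystemExit("PostgreSQL connection failed")
db.close()
```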
@ -1,28 +1,24 @@
# coding: utf8
""" Report """ """ Report """
import atexit import atexit
import logging import logging
from mylib.config import ConfigurableObject from mylib.config import ConfigurableObject, StringOption
from mylib.config import StringOption
from mylib.email import EmailClient from mylib.email import EmailClient
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
"""Logging report""" """Logging report"""
_config_name = 'report' _config_name = "report"
_config_comment = 'Email report' _config_comment = "Email report"
_defaults = { _defaults = {
'recipient': None, "recipient": None,
'subject': 'Report', "subject": "Report",
'loglevel': 'WARNING', "loglevel": "WARNING",
'logformat': '%(asctime)s - %(levelname)s - %(message)s', "logformat": "%(asctime)s - %(levelname)s - %(message)s",
} }
content = [] content = []
@ -43,17 +39,25 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
"""Configure options on registered mylib.Config object""" """Configure options on registered mylib.Config object"""
section = super().configure(**kwargs) section = super().configure(**kwargs)
section.add_option(StringOption, "recipient", comment="Report recipient email address")
section.add_option( section.add_option(
StringOption, 'recipient', comment='Report recipient email address') StringOption,
"subject",
default=self._defaults["subject"],
comment="Report email subject",
)
section.add_option( section.add_option(
StringOption, 'subject', default=self._defaults['subject'], StringOption,
comment='Report email subject') "loglevel",
default=self._defaults["loglevel"],
comment='Report log level (as accept by python logging, for instance "INFO")',
)
section.add_option( section.add_option(
StringOption, 'loglevel', default=self._defaults['loglevel'], StringOption,
comment='Report log level (as accept by python logging, for instance "INFO")') "logformat",
section.add_option( default=self._defaults["logformat"],
StringOption, 'logformat', default=self._defaults['logformat'], comment='Report log level (as accept by python logging, for instance "INFO")',
comment='Report log level (as accept by python logging, for instance "INFO")') )
if not self.email_client: if not self.email_client:
self.email_client = EmailClient(config=self._config) self.email_client = EmailClient(config=self._config)
@ -66,12 +70,11 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
super().initialize(loaded_config=loaded_config) super().initialize(loaded_config=loaded_config)
self.handler = logging.StreamHandler(self) self.handler = logging.StreamHandler(self)
loglevel = self._get_option('loglevel').upper() loglevel = self._get_option("loglevel").upper()
assert hasattr(logging, loglevel), ( assert hasattr(logging, loglevel), f"Invalid report loglevel {loglevel}"
f'Invalid report loglevel {loglevel}')
self.handler.setLevel(getattr(logging, loglevel)) self.handler.setLevel(getattr(logging, loglevel))
self.formatter = logging.Formatter(self._get_option('logformat')) self.formatter = logging.Formatter(self._get_option("logformat"))
self.handler.setFormatter(self.formatter) self.handler.setFormatter(self.formatter)
def get_handler(self): def get_handler(self):
@ -97,29 +100,34 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
def send(self, subject=None, rcpt_to=None, email_client=None, just_try=False): def send(self, subject=None, rcpt_to=None, email_client=None, just_try=False):
"""Send report using an EmailClient""" """Send report using an EmailClient"""
if rcpt_to is None: if rcpt_to is None:
rcpt_to = self._get_option('recipient') rcpt_to = self._get_option("recipient")
if not rcpt_to: if not rcpt_to:
log.debug('No report recipient, do not send report') log.debug("No report recipient, do not send report")
return True return True
if subject is None: if subject is None:
subject = self._get_option('subject') subject = self._get_option("subject")
assert subject, "You must provide report subject using Report.__init__ or Report.send" assert subject, "You must provide report subject using Report.__init__ or Report.send"
if email_client is None: if email_client is None:
email_client = self.email_client email_client = self.email_client
assert email_client, ( assert email_client, (
"You must provide an email client __init__(), send() or send_at_exit() methods argument email_client") "You must provide an email client __init__(), send() or send_at_exit() methods argument"
" email_client"
)
content = self.get_content() content = self.get_content()
if not content: if not content:
log.debug('Report is empty, do not send it') log.debug("Report is empty, do not send it")
return True return True
msg = email_client.forge_message( msg = email_client.forge_message(
rcpt_to, subject=subject, text_body=content, rcpt_to,
subject=subject,
text_body=content,
attachment_files=self._attachment_files, attachment_files=self._attachment_files,
attachment_payloads=self._attachment_payloads) attachment_payloads=self._attachment_payloads,
)
if email_client.send(rcpt_to, msg=msg, just_try=just_try): if email_client.send(rcpt_to, msg=msg, just_try=just_try):
log.debug('Report sent to %s', rcpt_to) log.debug("Report sent to %s", rcpt_to)
return True return True
log.error('Fail to send report to %s', rcpt_to) log.error("Fail to send report to %s", rcpt_to)
return False return False
def send_at_exit(self, **kwargs): def send_at_exit(self, **kwargs):
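For context, a minimal sketch of how this reporting flow is typically wired up in a calling script; the SMTP settings and addresses are placeholders, and it assumes `get_handler()` returns the `logging.StreamHandler` set up by `initialize()`:

```python
import logging

from mylib.email import EmailClient
from mylib.report import Report

# Placeholder SMTP settings and addresses, only for illustration.
email_client = EmailClient(smtp_host="127.0.0.1", smtp_port=25, sender_email="noreply@example.org")
report = Report(subject="Nightly job report", rcpt_to="admin@example.org")

# Every record logged through this handler ends up in the report body.
logging.getLogger().addHandler(report.get_handler())

# Send the report on process exit (it is skipped automatically when empty).
report.send_at_exit(email_client=email_client)

logging.getLogger(__name__).info("Something worth reporting happened")
```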
View file
@ -1,18 +1,14 @@
# -*- coding: utf-8 -*-
""" Test Email client """ """ Test Email client """
import datetime import datetime
import getpass
import logging import logging
import sys import sys
import getpass
from mako.template import Template as MakoTemplate from mako.template import Template as MakoTemplate
from mylib.scripts.helpers import get_opts_parser, add_email_opts from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging
from mylib.scripts.helpers import init_logging, init_email_client
log = logging.getLogger("mylib.scripts.email_test")
log = logging.getLogger('mylib.scripts.email_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -24,10 +20,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True) parser = get_opts_parser(just_try=True)
add_email_opts(parser) add_email_opts(parser)
test_opts = parser.add_argument_group('Test email options') test_opts = parser.add_argument_group("Test email options")
test_opts.add_argument( test_opts.add_argument(
'-t', '--to', "-t",
"--to",
action="store", action="store",
type=str, type=str,
dest="test_to", dest="test_to",
@ -35,7 +32,8 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
) )
test_opts.add_argument( test_opts.add_argument(
'-m', '--mako', "-m",
"--mako",
action="store_true", action="store_true",
dest="test_mako", dest="test_mako",
help="Test mako templating", help="Test mako templating",
@ -44,14 +42,14 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args() options = parser.parse_args()
if not options.test_to: if not options.test_to:
parser.error('You must specify test email recipient using -t/--to parameter') parser.error("You must specify test email recipient using -t/--to parameter")
sys.exit(1) sys.exit(1)
# Initialize logs # Initialize logs
init_logging(options, 'Test EmailClient') init_logging(options, "Test EmailClient")
if options.email_smtp_user and not options.email_smtp_password: if options.email_smtp_user and not options.email_smtp_password:
options.email_smtp_password = getpass.getpass('Please enter SMTP password: ') options.email_smtp_password = getpass.getpass("Please enter SMTP password: ")
email_client = init_email_client( email_client = init_email_client(
options, options,
@ -59,20 +57,24 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
test=dict( test=dict(
subject="Test email", subject="Test email",
text=( text=(
"Just a test email sent at {sent_date}." if not options.test_mako else "Just a test email sent at {sent_date}."
MakoTemplate("Just a test email sent at ${sent_date}.") if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
), ),
html=( html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else "<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>") if not options.test_mako
) else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
) )
),
) )
),
) )
log.info('Send a test email to %s', options.test_to) log.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
log.info('Test email sent') log.info("Test email sent")
sys.exit(0) sys.exit(0)
log.error('Fail to send test email') log.error("Fail to send test email")
sys.exit(1) sys.exit(1)
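A condensed sketch of the templating exercised above, assuming (as this script suggests) that plain string templates are rendered with the keyword arguments passed to `send()` while `MakoTemplate` objects go through Mako; the recipient, template name and body are illustrative:

```python
from mako.template import Template as MakoTemplate

from mylib.email import EmailClient

# Placeholder SMTP host and sender, only for illustration.
email_client = EmailClient(smtp_host="smtp.example.org", sender_email="noreply@example.org")
email_client.templates = dict(
    reminder=dict(
        subject="Reminder",
        # Plain string template: placeholders are filled from the kwargs given to send().
        text="Hello {name}, this is your reminder.",
        # Equivalent Mako template, using Mako's ${...} placeholder syntax.
        html=MakoTemplate("<p>Hello ${name}, this is your reminder.</p>"),
    )
)
email_client.send("someone@example.org", template="reminder", name="Alice")
```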
View file
@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" Test Email client using mylib.config.Config for configuration """ """ Test Email client using mylib.config.Config for configuration """
import datetime import datetime
import logging import logging
@ -10,7 +8,6 @@ from mako.template import Template as MakoTemplate
from mylib.config import Config from mylib.config import Config
from mylib.email import EmailClient from mylib.email import EmailClient
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -19,7 +16,7 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
config = Config(__doc__, __name__.replace('.', '_')) config = Config(__doc__, __name__.replace(".", "_"))
email_client = EmailClient(config=config) email_client = EmailClient(config=config)
email_client.configure() email_client.configure()
@ -27,10 +24,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
# Options parser # Options parser
parser = config.get_arguments_parser(description=__doc__) parser = config.get_arguments_parser(description=__doc__)
test_opts = parser.add_argument_group('Test email options') test_opts = parser.add_argument_group("Test email options")
test_opts.add_argument( test_opts.add_argument(
'-t', '--to', "-t",
"--to",
action="store", action="store",
type=str, type=str,
dest="test_to", dest="test_to",
@ -38,7 +36,8 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
) )
test_opts.add_argument( test_opts.add_argument(
'-m', '--mako', "-m",
"--mako",
action="store_true", action="store_true",
dest="test_mako", dest="test_mako",
help="Test mako templating", help="Test mako templating",
@ -47,26 +46,30 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = config.parse_arguments_options() options = config.parse_arguments_options()
if not options.test_to: if not options.test_to:
parser.error('You must specify test email recipient using -t/--to parameter') parser.error("You must specify test email recipient using -t/--to parameter")
sys.exit(1) sys.exit(1)
email_client.templates = dict( email_client.templates = dict(
test=dict( test=dict(
subject="Test email", subject="Test email",
text=( text=(
"Just a test email sent at {sent_date}." if not options.test_mako else "Just a test email sent at {sent_date}."
MakoTemplate("Just a test email sent at ${sent_date}.") if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
), ),
html=( html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else "<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>") if not options.test_mako
else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
) )
),
) )
) )
logging.info('Send a test email to %s', options.test_to) logging.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
logging.info('Test email sent') logging.info("Test email sent")
sys.exit(0) sys.exit(0)
logging.error('Fail to send test email') logging.error("Fail to send test email")
sys.exit(1) sys.exit(1)
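A rough sketch of the configuration-driven variant used here: `EmailClient.configure()` registers its options on the `Config` object, `parse_arguments_options()` loads them from the INI file and/or command line, and messages are then forged and sent as in `mylib.report`. Recipient and body are illustrative:

```python
import sys

from mylib.config import Config
from mylib.email import EmailClient

config = Config("My app")
email_client = EmailClient(config=config)
email_client.configure()
config.parse_arguments_options(argv=sys.argv[1:])

# Illustrative message; forge_message()/send() are used the same way in mylib.report.
msg = email_client.forge_message("someone@example.org", subject="Hello", text_body="Hi there!")
email_client.send("someone@example.org", msg=msg)
```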
View file
@ -1,20 +1,18 @@
# coding: utf8
""" Scripts helpers """ """ Scripts helpers """
import argparse import argparse
import getpass import getpass
import logging import logging
import os.path
import socket import socket
import sys import sys
import os.path
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def init_logging(options, name, report=None): def init_logging(options, name, report=None):
"""Initialize logging from calling script options""" """Initialize logging from calling script options"""
logformat = f'%(asctime)s - {name} - %(levelname)s - %(message)s' logformat = f"%(asctime)s - {name} - %(levelname)s - %(message)s"
if options.debug: if options.debug:
loglevel = logging.DEBUG loglevel = logging.DEBUG
elif options.verbose: elif options.verbose:
@ -46,59 +44,44 @@ def get_opts_parser(desc=None, just_try=False, just_one=False, progress=False, c
parser = argparse.ArgumentParser(description=desc) parser = argparse.ArgumentParser(description=desc)
parser.add_argument( parser.add_argument(
        "-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode"
    )
    parser.add_argument(
        "-d", "--debug", action="store_true", dest="debug", help="Enable debug mode"
    )
    parser.add_argument(
        "-l",
        "--log-file",
        action="store",
        type=str,
        dest="logfile",
        help=f'Log file path (default: {get_default_opt_value(config, default_config, "logfile")})',
        default=get_default_opt_value(config, default_config, "logfile"),
    )
    parser.add_argument(
        "-C",
        "--console",
        action="store_true",
        dest="console",
        help="Always log on console (even if log file is configured)",
    )
    if just_try:
        parser.add_argument(
            "-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode"
        )
    if just_one:
        parser.add_argument(
            "-J", "--just-one", action="store_true", dest="just_one", help="Enable just-one mode"
        )
    if progress:
        parser.add_argument(
            "-p", "--progress", action="store_true", dest="progress", help="Enable progress bar"
        )
return parser return parser
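Taken together, these options are meant to be consumed at the top of a script roughly as follows (script name and log message are illustrative):

```python
import logging

from mylib.scripts.helpers import get_opts_parser, init_logging

parser = get_opts_parser(desc="My maintenance script", just_try=True, progress=True)
# Script-specific arguments can be added to the parser here.
options = parser.parse_args()

# Configure console/file logging according to -v/-d/-l/-C.
init_logging(options, "My maintenance script")
logging.info("Started (just-try mode: %s)", options.just_try)
```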
@ -106,120 +89,143 @@ def get_opts_parser(desc=None, just_try=False, just_one=False, progress=False, c
def add_email_opts(parser, config=None): def add_email_opts(parser, config=None):
"""Add email options""" """Add email options"""
email_opts = parser.add_argument_group('Email options') email_opts = parser.add_argument_group("Email options")
default_config = dict( default_config = dict(
smtp_host="127.0.0.1", smtp_port=25, smtp_ssl=False, smtp_tls=False, smtp_user=None, smtp_host="127.0.0.1",
smtp_password=None, smtp_debug=False, email_encoding=sys.getdefaultencoding(), smtp_port=25,
sender_name=getpass.getuser(), sender_email=f'{getpass.getuser()}@{socket.gethostname()}', smtp_ssl=False,
catch_all=False smtp_tls=False,
smtp_user=None,
smtp_password=None,
smtp_debug=False,
email_encoding=sys.getdefaultencoding(),
sender_name=getpass.getuser(),
sender_email=f"{getpass.getuser()}@{socket.gethostname()}",
catch_all=False,
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-host', "--smtp-host",
action="store", action="store",
type=str, type=str,
dest="email_smtp_host", dest="email_smtp_host",
help=( help=f'SMTP host (default: {get_default_opt_value(config, default_config, "smtp_host")})',
'SMTP host (default: ' default=get_default_opt_value(config, default_config, "smtp_host"),
f'{get_default_opt_value(config, default_config, "smtp_host")})'),
default=get_default_opt_value(config, default_config, 'smtp_host')
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-port', "--smtp-port",
action="store", action="store",
type=int, type=int,
dest="email_smtp_port", dest="email_smtp_port",
help=f'SMTP port (default: {get_default_opt_value(config, default_config, "smtp_port")})', help=f'SMTP port (default: {get_default_opt_value(config, default_config, "smtp_port")})',
default=get_default_opt_value(config, default_config, 'smtp_port') default=get_default_opt_value(config, default_config, "smtp_port"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-ssl', "--smtp-ssl",
action="store_true", action="store_true",
dest="email_smtp_ssl", dest="email_smtp_ssl",
help=f'Use SSL (default: {get_default_opt_value(config, default_config, "smtp_ssl")})', help=f'Use SSL (default: {get_default_opt_value(config, default_config, "smtp_ssl")})',
default=get_default_opt_value(config, default_config, 'smtp_ssl') default=get_default_opt_value(config, default_config, "smtp_ssl"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-tls', "--smtp-tls",
action="store_true", action="store_true",
dest="email_smtp_tls", dest="email_smtp_tls",
help=f'Use TLS (default: {get_default_opt_value(config, default_config, "smtp_tls")})', help=f'Use TLS (default: {get_default_opt_value(config, default_config, "smtp_tls")})',
default=get_default_opt_value(config, default_config, 'smtp_tls') default=get_default_opt_value(config, default_config, "smtp_tls"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-user', "--smtp-user",
action="store", action="store",
type=str, type=str,
dest="email_smtp_user", dest="email_smtp_user",
help=f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})', help=(
default=get_default_opt_value(config, default_config, 'smtp_user') f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})'
),
default=get_default_opt_value(config, default_config, "smtp_user"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-password', "--smtp-password",
action="store", action="store",
type=str, type=str,
dest="email_smtp_password", dest="email_smtp_password",
help=f'SMTP password (default: {get_default_opt_value(config, default_config, "smtp_password")})', help=(
default=get_default_opt_value(config, default_config, 'smtp_password') "SMTP password (default:"
f' {get_default_opt_value(config, default_config, "smtp_password")})'
),
default=get_default_opt_value(config, default_config, "smtp_password"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-debug', "--smtp-debug",
action="store_true", action="store_true",
dest="email_smtp_debug", dest="email_smtp_debug",
help=f'Debug SMTP connection (default: {get_default_opt_value(config, default_config, "smtp_debug")})', help=(
default=get_default_opt_value(config, default_config, 'smtp_debug') "Debug SMTP connection (default:"
f' {get_default_opt_value(config, default_config, "smtp_debug")})'
),
default=get_default_opt_value(config, default_config, "smtp_debug"),
) )
email_opts.add_argument( email_opts.add_argument(
'--email-encoding', "--email-encoding",
action="store", action="store",
type=str, type=str,
dest="email_encoding", dest="email_encoding",
help=f'SMTP encoding (default: {get_default_opt_value(config, default_config, "email_encoding")})', help=(
default=get_default_opt_value(config, default_config, 'email_encoding') "SMTP encoding (default:"
f' {get_default_opt_value(config, default_config, "email_encoding")})'
),
default=get_default_opt_value(config, default_config, "email_encoding"),
) )
email_opts.add_argument( email_opts.add_argument(
'--sender-name', "--sender-name",
action="store", action="store",
type=str, type=str,
dest="email_sender_name", dest="email_sender_name",
help=f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})', help=(
default=get_default_opt_value(config, default_config, 'sender_name') f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})'
),
default=get_default_opt_value(config, default_config, "sender_name"),
) )
email_opts.add_argument( email_opts.add_argument(
'--sender-email', "--sender-email",
action="store", action="store",
type=str, type=str,
dest="email_sender_email", dest="email_sender_email",
help=f'Sender email (default: {get_default_opt_value(config, default_config, "sender_email")})', help=(
default=get_default_opt_value(config, default_config, 'sender_email') "Sender email (default:"
f' {get_default_opt_value(config, default_config, "sender_email")})'
),
default=get_default_opt_value(config, default_config, "sender_email"),
) )
email_opts.add_argument( email_opts.add_argument(
'--catch-all', "--catch-all",
action="store", action="store",
type=str, type=str,
dest="email_catch_all", dest="email_catch_all",
help=( help=(
'Catch all sent email: specify catch recipient email address ' "Catch all sent email: specify catch recipient email address "
f'(default: {get_default_opt_value(config, default_config, "catch_all")})'), f'(default: {get_default_opt_value(config, default_config, "catch_all")})'
default=get_default_opt_value(config, default_config, 'catch_all') ),
default=get_default_opt_value(config, default_config, "catch_all"),
) )
def init_email_client(options, **kwargs): def init_email_client(options, **kwargs):
"""Initialize email client from calling script options""" """Initialize email client from calling script options"""
from mylib.email import EmailClient # pylint: disable=import-outside-toplevel from mylib.email import EmailClient # pylint: disable=import-outside-toplevel
log.info('Initialize Email client')
log.info("Initialize Email client")
return EmailClient( return EmailClient(
smtp_host=options.email_smtp_host, smtp_host=options.email_smtp_host,
smtp_port=options.email_smtp_port, smtp_port=options.email_smtp_port,
@ -231,9 +237,9 @@ def init_email_client(options, **kwargs):
sender_name=options.email_sender_name, sender_name=options.email_sender_name,
sender_email=options.email_sender_email, sender_email=options.email_sender_email,
catch_all_addr=options.email_catch_all, catch_all_addr=options.email_catch_all,
just_try=options.just_try if hasattr(options, 'just_try') else False, just_try=options.just_try if hasattr(options, "just_try") else False,
encoding=options.email_encoding, encoding=options.email_encoding,
**kwargs **kwargs,
) )
@ -242,53 +248,51 @@ def add_sftp_opts(parser):
sftp_opts = parser.add_argument_group("SFTP options") sftp_opts = parser.add_argument_group("SFTP options")
sftp_opts.add_argument( sftp_opts.add_argument(
'-H', '--sftp-host', "-H",
"--sftp-host",
action="store", action="store",
type=str, type=str,
dest="sftp_host", dest="sftp_host",
help="SFTP Host (default: localhost)", help="SFTP Host (default: localhost)",
default='localhost' default="localhost",
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'--sftp-port', "--sftp-port",
action="store", action="store",
type=int, type=int,
dest="sftp_port", dest="sftp_port",
help="SFTP Port (default: 22)", help="SFTP Port (default: 22)",
default=22 default=22,
) )
sftp_opts.add_argument( sftp_opts.add_argument(
        "-u", "--sftp-user", action="store", type=str, dest="sftp_user", help="SFTP User"
    )
sftp_opts.add_argument( sftp_opts.add_argument(
'-P', '--sftp-password', "-P",
"--sftp-password",
action="store", action="store",
type=str, type=str,
dest="sftp_password", dest="sftp_password",
help="SFTP Password" help="SFTP Password",
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'--sftp-known-hosts', "--sftp-known-hosts",
action="store", action="store",
type=str, type=str,
dest="sftp_known_hosts", dest="sftp_known_hosts",
help="SFTP known_hosts file path (default: ~/.ssh/known_hosts)", help="SFTP known_hosts file path (default: ~/.ssh/known_hosts)",
default=os.path.expanduser('~/.ssh/known_hosts') default=os.path.expanduser("~/.ssh/known_hosts"),
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'--sftp-auto-add-unknown-host-key', "--sftp-auto-add-unknown-host-key",
action="store_true", action="store_true",
dest="sftp_auto_add_unknown_host_key", dest="sftp_auto_add_unknown_host_key",
help="Auto-add unknown SSH host key" help="Auto-add unknown SSH host key",
) )
return sftp_opts return sftp_opts
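These option groups compose with the generic parser above; a short sketch of a script using both (description and names are illustrative):

```python
from mylib.scripts.helpers import (
    add_email_opts,
    add_sftp_opts,
    get_opts_parser,
    init_email_client,
    init_logging,
)

parser = get_opts_parser(desc="Push exports over SFTP and report by email", just_try=True)
add_email_opts(parser)
add_sftp_opts(parser)
options = parser.parse_args()

init_logging(options, "Export push")
# Build an EmailClient from the --smtp-*/--sender-*/--catch-all options parsed above.
email_client = init_email_client(options)
```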
View file
@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" Test LDAP """ """ Test LDAP """
import datetime import datetime
import logging import logging
@ -8,12 +6,10 @@ import sys
import dateutil.tz import dateutil.tz
import pytz import pytz
from mylib.ldap import format_datetime, format_date, parse_datetime, parse_date from mylib.ldap import format_date, format_datetime, parse_date, parse_datetime
from mylib.scripts.helpers import get_opts_parser from mylib.scripts.helpers import get_opts_parser, init_logging
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.ldap_test")
log = logging.getLogger('mylib.scripts.ldap_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -26,52 +22,121 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args() options = parser.parse_args()
# Initialize logs # Initialize logs
init_logging(options, 'Test LDAP helpers') init_logging(options, "Test LDAP helpers")
now = datetime.datetime.now().replace(tzinfo=dateutil.tz.tzlocal()) now = datetime.datetime.now().replace(tzinfo=dateutil.tz.tzlocal())
print(f'Now = {now}') print(f"Now = {now}")
datestring_now = format_datetime(now) datestring_now = format_datetime(now)
print(f'format_datetime : {datestring_now}') print(f"format_datetime : {datestring_now}")
print(f'format_datetime (from_timezone=utc) : {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}') print(
print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}') "format_datetime (from_timezone=utc) :"
print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone="local")}') f" {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}"
print(f'format_datetime (from_timezone=Paris) : {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}') )
print(f'format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}') print(
print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}') "format_datetime (from_timezone=local) :"
f" {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}"
)
print(
"format_datetime (from_timezone=local) :"
f' {format_datetime(now.replace(tzinfo=None), from_timezone="local")}'
)
print(
"format_datetime (from_timezone=Paris) :"
f' {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}'
)
print(f"format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}")
print(
"format_datetime (to_timezone=local) :"
f" {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone="local")}') print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone="local")}')
print(f'format_datetime (to_timezone=Tokyo) : {format_datetime(now, to_timezone="Asia/Tokyo")}') print(f'format_datetime (to_timezone=Tokyo) : {format_datetime(now, to_timezone="Asia/Tokyo")}')
print(f'format_datetime (naive=True) : {format_datetime(now, naive=True)}') print(f"format_datetime (naive=True) : {format_datetime(now, naive=True)}")
print(f'format_date : {format_date(now)}') print(f"format_date : {format_date(now)}")
print(f'format_date (from_timezone=utc) : {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}') print(
print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}') "format_date (from_timezone=utc) :"
print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone="local")}') f" {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}"
print(f'format_date (from_timezone=Paris) : {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}') )
print(f'format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}') print(
print(f'format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}') "format_date (from_timezone=local) :"
f" {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}"
)
print(
"format_date (from_timezone=local) :"
f' {format_date(now.replace(tzinfo=None), from_timezone="local")}'
)
print(
"format_date (from_timezone=Paris) :"
f' {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}'
)
print(f"format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}")
print(
f"format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'format_date (to_timezone=local) : {format_date(now, to_timezone="local")}') print(f'format_date (to_timezone=local) : {format_date(now, to_timezone="local")}')
print(f'format_date (to_timezone=Tokyo) : {format_date(now, to_timezone="Asia/Tokyo")}') print(f'format_date (to_timezone=Tokyo) : {format_date(now, to_timezone="Asia/Tokyo")}')
print(f'format_date (naive=True) : {format_date(now, naive=True)}') print(f"format_date (naive=True) : {format_date(now, naive=True)}")
print(f'parse_datetime : {parse_datetime(datestring_now)}') print(f"parse_datetime : {parse_datetime(datestring_now)}")
print(f'parse_datetime (default_timezone=utc) : {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}') print(
print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}') "parse_datetime (default_timezone=utc) :"
print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone="local")}') f" {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}"
print(f'parse_datetime (default_timezone=Paris) : {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}') )
print(f'parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}') print(
print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}') "parse_datetime (default_timezone=local) :"
print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone="local")}') f" {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}"
print(f'parse_datetime (to_timezone=Tokyo) : {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}') )
print(f'parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}') print(
"parse_datetime (default_timezone=local) :"
f' {parse_datetime(datestring_now[0:-1], default_timezone="local")}'
)
print(
"parse_datetime (default_timezone=Paris) :"
f' {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}'
)
print(
f"parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}"
)
print(
"parse_datetime (to_timezone=local) :"
f" {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}"
)
print(
"parse_datetime (to_timezone=local) :"
f' {parse_datetime(datestring_now, to_timezone="local")}'
)
print(
"parse_datetime (to_timezone=Tokyo) :"
f' {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}'
)
print(f"parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}")
print(f'parse_date : {parse_date(datestring_now)}') print(f"parse_date : {parse_date(datestring_now)}")
print(f'parse_date (default_timezone=utc) : {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}') print(
print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}') "parse_date (default_timezone=utc) :"
print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone="local")}') f" {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}"
print(f'parse_date (default_timezone=Paris) : {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}') )
print(f'parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}') print(
print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}') "parse_date (default_timezone=local) :"
f" {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}"
)
print(
"parse_date (default_timezone=local) :"
f' {parse_date(datestring_now[0:-1], default_timezone="local")}'
)
print(
"parse_date (default_timezone=Paris) :"
f' {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}'
)
print(f"parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}")
print(
"parse_date (to_timezone=local) :"
f" {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone="local")}') print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone="local")}')
print(f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}') print(
print(f'parse_date (naive=True) : {parse_date(datestring_now, naive=True)}') f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}'
)
print(f"parse_date (naive=True) : {parse_date(datestring_now, naive=True)}")
View file
@ -64,6 +64,6 @@ def main(argv=None):
"mail": {"order": 12, "key": "email", "convert": lambda x: x.lower().strip()}, "mail": {"order": 12, "key": "email", "convert": lambda x: x.lower().strip()},
} }
print('Mapping source:\n' + pretty_format_value(src)) print("Mapping source:\n" + pretty_format_value(src))
print('Mapping config:\n' + pretty_format_value(map_c)) print("Mapping config:\n" + pretty_format_value(map_c))
print('Mapping result:\n' + pretty_format_value(map_hash(map_c, src))) print("Mapping result:\n" + pretty_format_value(map_hash(map_c, src)))
View file
@ -1,16 +1,12 @@
# -*- coding: utf-8 -*-
""" Test Progress bar """ """ Test Progress bar """
import logging import logging
import time
import sys import sys
import time
from mylib.pbar import Pbar from mylib.pbar import Pbar
from mylib.scripts.helpers import get_opts_parser from mylib.scripts.helpers import get_opts_parser, init_logging
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.pbar_test")
log = logging.getLogger('mylib.scripts.pbar_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -23,20 +19,21 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(progress=True) parser = get_opts_parser(progress=True)
parser.add_argument( parser.add_argument(
'-c', '--count', "-c",
"--count",
action="store", action="store",
type=int, type=int,
dest="count", dest="count",
help=f'Progress bar max value (default: {default_max_val})', help=f"Progress bar max value (default: {default_max_val})",
default=default_max_val default=default_max_val,
) )
options = parser.parse_args() options = parser.parse_args()
# Initialize logs # Initialize logs
init_logging(options, 'Test Pbar') init_logging(options, "Test Pbar")
pbar = Pbar('Test', options.count, enabled=options.progress) pbar = Pbar("Test", options.count, enabled=options.progress)
for idx in range(0, options.count): # pylint: disable=unused-variable for idx in range(0, options.count): # pylint: disable=unused-variable
pbar.increment() pbar.increment()
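The same pattern stripped to its essentials (the item count is arbitrary, and `enabled=False` is assumed to turn the bar into a no-op, as the `--progress` flag suggests):

```python
from mylib.pbar import Pbar

pbar = Pbar("Processing items", 100, enabled=True)
for _ in range(100):
    # ... one unit of work ...
    pbar.increment()
```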
View file
@ -1,15 +1,11 @@
# -*- coding: utf-8 -*-
""" Test report """ """ Test report """
import logging import logging
import sys import sys
from mylib.report import Report from mylib.report import Report
from mylib.scripts.helpers import get_opts_parser, add_email_opts from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging
from mylib.scripts.helpers import init_logging, init_email_client
log = logging.getLogger("mylib.scripts.report_test")
log = logging.getLogger('mylib.scripts.report_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -21,14 +17,10 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True) parser = get_opts_parser(just_try=True)
add_email_opts(parser) add_email_opts(parser)
report_opts = parser.add_argument_group('Report options') report_opts = parser.add_argument_group("Report options")
report_opts.add_argument( report_opts.add_argument(
        "-t", "--to", action="store", type=str, dest="report_rcpt", help="Send report to this email"
    )
options = parser.parse_args() options = parser.parse_args()
@ -37,13 +29,13 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser.error("You must specify a report recipient using -t/--to parameter") parser.error("You must specify a report recipient using -t/--to parameter")
# Initialize logs # Initialize logs
report = Report(rcpt_to=options.report_rcpt, subject='Test report') report = Report(rcpt_to=options.report_rcpt, subject="Test report")
init_logging(options, 'Test Report', report=report) init_logging(options, "Test Report", report=report)
email_client = init_email_client(options) email_client = init_email_client(options)
report.send_at_exit(email_client=email_client) report.send_at_exit(email_client=email_client)
logging.debug('Test debug message') logging.debug("Test debug message")
logging.info('Test info message') logging.info("Test info message")
logging.warning('Test warning message') logging.warning("Test warning message")
logging.error('Test error message') logging.error("Test error message")
View file
@ -1,22 +1,17 @@
# -*- coding: utf-8 -*-
""" Test SFTP client """ """ Test SFTP client """
import atexit import atexit
import tempfile import getpass
import logging import logging
import sys
import os import os
import random import random
import string import string
import sys
import tempfile
import getpass from mylib.scripts.helpers import add_sftp_opts, get_opts_parser, init_logging
from mylib.sftp import SFTPClient from mylib.sftp import SFTPClient
from mylib.scripts.helpers import get_opts_parser, add_sftp_opts
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.sftp_test")
log = logging.getLogger('mylib.scripts.sftp_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
@ -28,10 +23,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True) parser = get_opts_parser(just_try=True)
add_sftp_opts(parser) add_sftp_opts(parser)
test_opts = parser.add_argument_group('Test SFTP options') test_opts = parser.add_argument_group("Test SFTP options")
test_opts.add_argument( test_opts.add_argument(
'-p', '--remote-upload-path', "-p",
"--remote-upload-path",
action="store", action="store",
type=str, type=str,
dest="upload_path", dest="upload_path",
@ -41,66 +37,68 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args() options = parser.parse_args()
# Initialize logs # Initialize logs
init_logging(options, 'Test SFTP client') init_logging(options, "Test SFTP client")
if options.sftp_user and not options.sftp_password: if options.sftp_user and not options.sftp_password:
options.sftp_password = getpass.getpass('Please enter SFTP password: ') options.sftp_password = getpass.getpass("Please enter SFTP password: ")
    log.info("Initialize SFTP client")
sftp = SFTPClient(options=options, just_try=options.just_try) sftp = SFTPClient(options=options, just_try=options.just_try)
sftp.connect() sftp.connect()
atexit.register(sftp.close) atexit.register(sftp.close)
    log.debug("Create temporary file")
test_content = b'Juste un test.' test_content = b"Juste un test."
tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
tmp_file = os.path.join( tmp_file = os.path.join(
        tmp_dir.name, f'tmp{"".join(random.choice(string.ascii_lowercase) for i in range(8))}'
    )
log.debug('Temporary file path: "%s"', tmp_file) log.debug('Temporary file path: "%s"', tmp_file)
with open(tmp_file, 'wb') as file_desc: with open(tmp_file, "wb") as file_desc:
file_desc.write(test_content) file_desc.write(test_content)
log.debug( log.debug(
        "Upload file %s to SFTP server (in %s)",
        tmp_file,
        options.upload_path if options.upload_path else "remote initial connection directory",
    )
if not sftp.upload_file(tmp_file, options.upload_path): if not sftp.upload_file(tmp_file, options.upload_path):
log.error('Fail to upload test file on SFTP server') log.error("Fail to upload test file on SFTP server")
sys.exit(1) sys.exit(1)
log.info('Test file uploaded on SFTP server') log.info("Test file uploaded on SFTP server")
remote_filepath = ( remote_filepath = (
os.path.join(options.upload_path, os.path.basename(tmp_file)) os.path.join(options.upload_path, os.path.basename(tmp_file))
        if options.upload_path
        else os.path.basename(tmp_file)
) )
with tempfile.NamedTemporaryFile() as tmp_file2: with tempfile.NamedTemporaryFile() as tmp_file2:
log.info('Retrieve test file to %s', tmp_file2.name) log.info("Retrieve test file to %s", tmp_file2.name)
if not sftp.get_file(remote_filepath, tmp_file2.name): if not sftp.get_file(remote_filepath, tmp_file2.name):
log.error('Fail to retrieve test file') log.error("Fail to retrieve test file")
else: else:
with open(tmp_file2.name, 'rb') as file_desc: with open(tmp_file2.name, "rb") as file_desc:
content = file_desc.read() content = file_desc.read()
log.debug('Read content: %s', content) log.debug("Read content: %s", content)
if test_content == content: if test_content == content:
                    log.info("Content of the retrieved file matches the uploaded one")
                else:
                    log.error("Content of the retrieved file does not match the uploaded one")
    try:
        log.info("Remotely open test file %s", remote_filepath)
        file_desc = sftp.open_file(remote_filepath)
        content = file_desc.read()
        log.debug("Read content: %s", content)
        if test_content == content:
            log.info("Content of remote file matches the uploaded one")
        else:
            log.error("Content of remote file does not match the uploaded one")
    except Exception:  # pylint: disable=broad-except
        log.exception("An exception occurred remotely opening test file %s", remote_filepath)
if sftp.remove_file(remote_filepath): if sftp.remove_file(remote_filepath):
log.info('Test file removed on SFTP server') log.info("Test file removed on SFTP server")
else: else:
log.error('Fail to remove test file on SFTP server') log.error("Fail to remove test file on SFTP server")
View file
@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
""" SFTP client """ """ SFTP client """
import logging import logging
import os import os
from paramiko import AutoAddPolicy, SFTPAttributes, SSHClient

from mylib.config import (
    BooleanOption,
    ConfigurableObject,
    IntegerOption,
    PasswordOption,
    StringOption,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -23,16 +23,16 @@ class SFTPClient(ConfigurableObject):
This class abstract all interactions with the SFTP server. This class abstract all interactions with the SFTP server.
""" """
_config_name = 'sftp' _config_name = "sftp"
_config_comment = 'SFTP' _config_comment = "SFTP"
_defaults = { _defaults = {
'host': 'localhost', "host": "localhost",
'port': 22, "port": 22,
'user': None, "user": None,
'password': None, "password": None,
'known_hosts': os.path.expanduser('~/.ssh/known_hosts'), "known_hosts": os.path.expanduser("~/.ssh/known_hosts"),
'auto_add_unknown_host_key': False, "auto_add_unknown_host_key": False,
'just_try': False, "just_try": False,
} }
ssh_client = None ssh_client = None
@ -45,30 +45,48 @@ class SFTPClient(ConfigurableObject):
section = super().configure(**kwargs) section = super().configure(**kwargs)
section.add_option( section.add_option(
            StringOption,
            "host",
            default=self._defaults["host"],
            comment="SFTP server hostname/IP address",
        )
        section.add_option(
            IntegerOption, "port", default=self._defaults["port"], comment="SFTP server port"
        )
        section.add_option(
            StringOption,
            "user",
            default=self._defaults["user"],
            comment="SFTP authentication username",
        )
        section.add_option(
            PasswordOption,
            "password",
            default=self._defaults["password"],
            comment='SFTP authentication password (set to "keyring" to use XDG keyring)',
            username_option="user",
            keyring_value="keyring",
        )
        section.add_option(
            StringOption,
            "known_hosts",
            default=self._defaults["known_hosts"],
            comment="SFTP known_hosts filepath",
        )
        section.add_option(
            BooleanOption,
            "auto_add_unknown_host_key",
            default=self._defaults["auto_add_unknown_host_key"],
            comment="Auto add unknown host key",
        )
        if just_try:
            section.add_option(
                BooleanOption,
                "just_try",
                default=self._defaults["just_try"],
                comment="Just-try mode: do not really make change on remote SFTP host",
            )
return section return section
@ -80,19 +98,20 @@ class SFTPClient(ConfigurableObject):
"""Connect to SFTP server""" """Connect to SFTP server"""
if self.ssh_client: if self.ssh_client:
return return
host = self._get_option('host') host = self._get_option("host")
port = self._get_option('port') port = self._get_option("port")
log.info("Connect to SFTP server %s:%d", host, port) log.info("Connect to SFTP server %s:%d", host, port)
self.ssh_client = SSHClient() self.ssh_client = SSHClient()
if self._get_option('known_hosts'): if self._get_option("known_hosts"):
self.ssh_client.load_host_keys(self._get_option('known_hosts')) self.ssh_client.load_host_keys(self._get_option("known_hosts"))
if self._get_option('auto_add_unknown_host_key'): if self._get_option("auto_add_unknown_host_key"):
log.debug('Set missing host key policy to auto-add') log.debug("Set missing host key policy to auto-add")
self.ssh_client.set_missing_host_key_policy(AutoAddPolicy()) self.ssh_client.set_missing_host_key_policy(AutoAddPolicy())
self.ssh_client.connect( self.ssh_client.connect(
            host,
            port=port,
            username=self._get_option("user"),
            password=self._get_option("password"),
) )
self.sftp_client = self.ssh_client.open_sftp() self.sftp_client = self.ssh_client.open_sftp()
self.initial_directory = self.sftp_client.getcwd() self.initial_directory = self.sftp_client.getcwd()
@ -108,7 +127,7 @@ class SFTPClient(ConfigurableObject):
log.debug("Retreive file '%s' to '%s'", remote_filepath, local_filepath) log.debug("Retreive file '%s' to '%s'", remote_filepath, local_filepath)
return self.sftp_client.get(remote_filepath, local_filepath) is None return self.sftp_client.get(remote_filepath, local_filepath) is None
def open_file(self, remote_filepath, mode='r'): def open_file(self, remote_filepath, mode="r"):
"""Remotly open a file on SFTP server""" """Remotly open a file on SFTP server"""
self.connect() self.connect()
log.debug("Remotly open file '%s'", remote_filepath) log.debug("Remotly open file '%s'", remote_filepath)
@ -119,13 +138,13 @@ class SFTPClient(ConfigurableObject):
self.connect() self.connect()
remote_filepath = os.path.join( remote_filepath = os.path.join(
remote_directory if remote_directory else self.initial_directory, remote_directory if remote_directory else self.initial_directory,
os.path.basename(filepath) os.path.basename(filepath),
) )
log.debug("Upload file '%s' to '%s'", filepath, remote_filepath) log.debug("Upload file '%s' to '%s'", filepath, remote_filepath)
if self._get_option('just_try'): if self._get_option("just_try"):
log.debug( log.debug(
"Just-try mode: do not really upload file '%s' to '%s'", "Just-try mode: do not really upload file '%s' to '%s'", filepath, remote_filepath
filepath, remote_filepath) )
return True return True
result = self.sftp_client.put(filepath, remote_filepath) result = self.sftp_client.put(filepath, remote_filepath)
return isinstance(result, SFTPAttributes) return isinstance(result, SFTPAttributes)
@ -134,7 +153,7 @@ class SFTPClient(ConfigurableObject):
"""Remove a file on SFTP server""" """Remove a file on SFTP server"""
self.connect() self.connect()
log.debug("Remove file '%s'", filepath) log.debug("Remove file '%s'", filepath)
if self._get_option('just_try'): if self._get_option("just_try"):
log.debug("Just - try mode: do not really remove file '%s'", filepath) log.debug("Just - try mode: do not really remove file '%s'", filepath)
return True return True
return self.sftp_client.remove(filepath) is None return self.sftp_client.remove(filepath) is None
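Outside the test script, the client is typically driven from an options namespace like the one produced by `add_sftp_opts()`; a minimal sketch with placeholder connection settings (the exact set of attributes read from `options` is an assumption mirroring `mylib.scripts.sftp_test`):

```python
import argparse

from mylib.sftp import SFTPClient

# Placeholder settings, named after the dest= values defined by add_sftp_opts().
options = argparse.Namespace(
    sftp_host="sftp.example.org",
    sftp_port=22,
    sftp_user="backup",
    sftp_password="secret",
    sftp_known_hosts=None,
    sftp_auto_add_unknown_host_key=True,
    just_try=False,
)

sftp = SFTPClient(options=options, just_try=options.just_try)
sftp.connect()
sftp.upload_file("/tmp/export.csv", "incoming")  # uploaded as incoming/export.csv
remote_file = sftp.open_file("incoming/export.csv")
print(remote_file.read())
sftp.remove_file("incoming/export.csv")
sftp.close()
```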
View file
@ -13,15 +13,15 @@ class TelltaleFile:
def __init__(self, filepath=None, filename=None, dirpath=None): def __init__(self, filepath=None, filename=None, dirpath=None):
assert filepath or filename, "filename or filepath is required" assert filepath or filename, "filename or filepath is required"
if filepath: if filepath:
            assert (
                not filename or os.path.basename(filepath) == filename
            ), "filepath and filename do not match"
            assert (
                not dirpath or os.path.dirname(filepath) == dirpath
            ), "filepath and dirpath do not match"
self.filename = filename if filename else os.path.basename(filepath) self.filename = filename if filename else os.path.basename(filepath)
self.dirpath = ( self.dirpath = (
            dirpath if dirpath else (os.path.dirname(filepath) if filepath else os.getcwd())
        )
self.filepath = filepath if filepath else os.path.join(self.dirpath, self.filename) self.filepath = filepath if filepath else os.path.join(self.dirpath, self.filename)
@ -29,21 +29,19 @@ class TelltaleFile:
def last_update(self): def last_update(self):
"""Retreive last update datetime of the telltall file""" """Retreive last update datetime of the telltall file"""
try: try:
return datetime.datetime.fromtimestamp( return datetime.datetime.fromtimestamp(os.stat(self.filepath).st_mtime)
os.stat(self.filepath).st_mtime
)
except FileNotFoundError: except FileNotFoundError:
log.info('Telltale file not found (%s)', self.filepath) log.info("Telltale file not found (%s)", self.filepath)
return None return None
def update(self): def update(self):
"""Update the telltale file""" """Update the telltale file"""
log.info('Update telltale file (%s)', self.filepath) log.info("Update telltale file (%s)", self.filepath)
try: try:
os.utime(self.filepath, None) os.utime(self.filepath, None)
except FileNotFoundError: except FileNotFoundError:
# pylint: disable=consider-using-with # pylint: disable=consider-using-with
open(self.filepath, 'a', encoding="utf-8").close() open(self.filepath, "a", encoding="utf-8").close()
def remove(self): def remove(self):
"""Remove the telltale file""" """Remove the telltale file"""
View file
@ -1,2 +1,3 @@
[flake8] [flake8]
ignore = E501,W503 ignore = E501,W503
max-line-length = 100
View file
@ -1,76 +1,74 @@
#!/usr/bin/env python #!/usr/bin/env python
"""Setuptools script"""

from setuptools import find_packages, setup
extras_require = { extras_require = {
'dev': [ "dev": [
'pytest', "pytest",
'mocker', "mocker",
'pytest-mock', "pytest-mock",
'pylint', "pylint == 2.15.10",
'flake8', "pre-commit",
], ],
'config': [ "config": [
'argcomplete', "argcomplete",
'keyring', "keyring",
'systemd-python', "systemd-python",
], ],
'ldap': [ "ldap": [
'python-ldap', "python-ldap",
'python-dateutil', "python-dateutil",
'pytz', "pytz",
], ],
'email': [ "email": [
'mako', "mako",
], ],
'pgsql': [ "pgsql": [
'psycopg2', "psycopg2",
], ],
'oracle': [ "oracle": [
'cx_Oracle', "cx_Oracle",
], ],
'mysql': [ "mysql": [
'mysqlclient', "mysqlclient",
], ],
'sftp': [ "sftp": [
'paramiko', "paramiko",
], ],
} }
install_requires = ['progressbar'] install_requires = ["progressbar"]
for extra, deps in extras_require.items(): for extra, deps in extras_require.items():
if extra != 'dev': if extra != "dev":
install_requires.extend(deps) install_requires.extend(deps)
version = '0.1' version = "0.1"
setup( setup(
name="mylib", name="mylib",
version=version, version=version,
description='A set of helpers small libs to make common tasks easier in my script development', description="A set of helpers small libs to make common tasks easier in my script development",
classifiers=[ classifiers=[
'Programming Language :: Python', "Programming Language :: Python",
], ],
install_requires=install_requires, install_requires=install_requires,
extras_require=extras_require, extras_require=extras_require,
author='Benjamin Renard', author="Benjamin Renard",
author_email='brenard@zionetrix.net', author_email="brenard@zionetrix.net",
url='https://gogs.zionetrix.net/bn8/python-mylib', url="https://gogs.zionetrix.net/bn8/python-mylib",
packages=find_packages(), packages=find_packages(),
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
entry_points={ entry_points={
'console_scripts': [ "console_scripts": [
'mylib-test-email = mylib.scripts.email_test:main', "mylib-test-email = mylib.scripts.email_test:main",
'mylib-test-email-with-config = mylib.scripts.email_test_with_config:main', "mylib-test-email-with-config = mylib.scripts.email_test_with_config:main",
'mylib-test-map = mylib.scripts.map_test:main', "mylib-test-map = mylib.scripts.map_test:main",
'mylib-test-pbar = mylib.scripts.pbar_test:main', "mylib-test-pbar = mylib.scripts.pbar_test:main",
'mylib-test-report = mylib.scripts.report_test:main', "mylib-test-report = mylib.scripts.report_test:main",
'mylib-test-ldap = mylib.scripts.ldap_test:main', "mylib-test-ldap = mylib.scripts.ldap_test:main",
'mylib-test-sftp = mylib.scripts.sftp_test:main', "mylib-test-sftp = mylib.scripts.sftp_test:main",
], ],
}, },
) )
View file
@ -22,18 +22,11 @@ echo "Install package with dev dependencies using pip..."
$VENV/bin/python3 -m pip install -e ".[dev]" $QUIET_ARG $VENV/bin/python3 -m pip install -e ".[dev]" $QUIET_ARG
RES=0 RES=0
# Run tests
$VENV/bin/python3 -m pytest tests
[ $? -ne 0 ] && RES=1
# Run pylint # Run pre-commit
echo "Run pylint..." echo "Run pre-commit..."
$VENV/bin/pylint --extension-pkg-whitelist=cx_Oracle mylib tests source $VENV/bin/activate
[ $? -ne 0 ] && RES=1 pre-commit run --all-files
# Run flake8
echo "Run flake8..."
$VENV/bin/flake8 mylib tests
[ $? -ne 0 ] && RES=1 [ $? -ne 0 ] && RES=1
# Clean temporary venv # Clean temporary venv
View file
@ -2,31 +2,29 @@
# pylint: disable=global-variable-not-assigned # pylint: disable=global-variable-not-assigned
""" Tests on config lib """ """ Tests on config lib """
import configparser
import logging import logging
import os import os
import configparser
import pytest import pytest
from mylib.config import Config, ConfigSection from mylib.config import BooleanOption, Config, ConfigSection, StringOption
from mylib.config import BooleanOption
from mylib.config import StringOption
runned = {} runned = {}
def test_config_init_default_args(): def test_config_init_default_args():
appname = 'Test app' appname = "Test app"
config = Config(appname) config = Config(appname)
assert config.appname == appname assert config.appname == appname
assert config.version == '0.0' assert config.version == "0.0"
assert config.encoding == 'utf-8' assert config.encoding == "utf-8"
def test_config_init_custom_args(): def test_config_init_custom_args():
appname = 'Test app' appname = "Test app"
version = '1.43' version = "1.43"
encoding = 'ISO-8859-1' encoding = "ISO-8859-1"
config = Config(appname, version=version, encoding=encoding) config = Config(appname, version=version, encoding=encoding)
assert config.appname == appname assert config.appname == appname
assert config.version == version assert config.version == version
@ -34,8 +32,8 @@ def test_config_init_custom_args():
def test_add_section_default_args(): def test_add_section_default_args():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
section = config.add_section(name) section = config.add_section(name)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
assert config.sections[name] == section assert config.sections[name] == section
@ -45,9 +43,9 @@ def test_add_section_default_args():
def test_add_section_custom_args(): def test_add_section_custom_args():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
comment = 'Test' comment = "Test"
order = 20 order = 20
section = config.add_section(name, comment=comment, order=order) section = config.add_section(name, comment=comment, order=order)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
@ -57,47 +55,47 @@ def test_add_section_custom_args():
def test_add_section_with_callback(): def test_add_section_with_callback():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
global runned global runned
runned['test_add_section_with_callback'] = False runned["test_add_section_with_callback"] = False
def test_callback(loaded_config): def test_callback(loaded_config):
global runned global runned
assert loaded_config == config assert loaded_config == config
assert runned['test_add_section_with_callback'] is False assert runned["test_add_section_with_callback"] is False
runned['test_add_section_with_callback'] = True runned["test_add_section_with_callback"] = True
section = config.add_section(name, loaded_callback=test_callback) section = config.add_section(name, loaded_callback=test_callback)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
assert test_callback in config._loaded_callbacks assert test_callback in config._loaded_callbacks
assert runned['test_add_section_with_callback'] is False assert runned["test_add_section_with_callback"] is False
config.parse_arguments_options(argv=[], create=False) config.parse_arguments_options(argv=[], create=False)
assert runned['test_add_section_with_callback'] is True assert runned["test_add_section_with_callback"] is True
assert test_callback in config._loaded_callbacks_executed assert test_callback in config._loaded_callbacks_executed
    # Try to execute again to verify callback is not run again    # Try to execute again to verify callback is not run again
config._loaded() config._loaded()
def test_add_section_with_callback_already_loaded(): def test_add_section_with_callback_already_loaded():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
config.parse_arguments_options(argv=[], create=False) config.parse_arguments_options(argv=[], create=False)
global runned global runned
runned['test_add_section_with_callback_already_loaded'] = False runned["test_add_section_with_callback_already_loaded"] = False
def test_callback(loaded_config): def test_callback(loaded_config):
global runned global runned
assert loaded_config == config assert loaded_config == config
assert runned['test_add_section_with_callback_already_loaded'] is False assert runned["test_add_section_with_callback_already_loaded"] is False
runned['test_add_section_with_callback_already_loaded'] = True runned["test_add_section_with_callback_already_loaded"] = True
section = config.add_section(name, loaded_callback=test_callback) section = config.add_section(name, loaded_callback=test_callback)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
assert runned['test_add_section_with_callback_already_loaded'] is True assert runned["test_add_section_with_callback_already_loaded"] is True
assert test_callback in config._loaded_callbacks assert test_callback in config._loaded_callbacks
assert test_callback in config._loaded_callbacks_executed assert test_callback in config._loaded_callbacks_executed
    # Try to execute again to verify callback is not run again    # Try to execute again to verify callback is not run again
@ -105,10 +103,10 @@ def test_add_section_with_callback_already_loaded():
def test_add_option_default_args(): def test_add_option_default_args():
config = Config('Test app') config = Config("Test app")
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
name = 'my_option' name = "my_option"
option = section.add_option(StringOption, name) option = section.add_option(StringOption, name)
assert isinstance(option, StringOption) assert isinstance(option, StringOption)
assert name in section.options and section.options[name] == option assert name in section.options and section.options[name] == option
@ -124,17 +122,17 @@ def test_add_option_default_args():
def test_add_option_custom_args(): def test_add_option_custom_args():
config = Config('Test app') config = Config("Test app")
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
name = 'my_option' name = "my_option"
kwargs = dict( kwargs = dict(
default='default value', default="default value",
comment='my comment', comment="my comment",
no_arg=True, no_arg=True,
arg='--my-option', arg="--my-option",
short_arg='-M', short_arg="-M",
arg_help='My help' arg_help="My help",
) )
option = section.add_option(StringOption, name, **kwargs) option = section.add_option(StringOption, name, **kwargs)
assert isinstance(option, StringOption) assert isinstance(option, StringOption)
@ -148,12 +146,12 @@ def test_add_option_custom_args():
def test_defined(): def test_defined():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
assert not config.defined(section_name, opt_name) assert not config.defined(section_name, opt_name)
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
section.add_option(StringOption, opt_name) section.add_option(StringOption, opt_name)
@ -161,29 +159,29 @@ def test_defined():
def test_isset(): def test_isset():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
assert not config.isset(section_name, opt_name) assert not config.isset(section_name, opt_name)
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
option = section.add_option(StringOption, opt_name) option = section.add_option(StringOption, opt_name)
assert not config.isset(section_name, opt_name) assert not config.isset(section_name, opt_name)
config.parse_arguments_options(argv=[option.parser_argument_name, 'value'], create=False) config.parse_arguments_options(argv=[option.parser_argument_name, "value"], create=False)
assert config.isset(section_name, opt_name) assert config.isset(section_name, opt_name)
def test_not_isset(): def test_not_isset():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
assert not config.isset(section_name, opt_name) assert not config.isset(section_name, opt_name)
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
section.add_option(StringOption, opt_name) section.add_option(StringOption, opt_name)
@ -195,11 +193,11 @@ def test_not_isset():
def test_get(): def test_get():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
opt_value = 'value' opt_value = "value"
section = config.add_section('my_section') section = config.add_section("my_section")
option = section.add_option(StringOption, opt_name) option = section.add_option(StringOption, opt_name)
config.parse_arguments_options(argv=[option.parser_argument_name, opt_value], create=False) config.parse_arguments_options(argv=[option.parser_argument_name, opt_value], create=False)
@ -207,11 +205,11 @@ def test_get():
def test_get_default(): def test_get_default():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
opt_default_value = 'value' opt_default_value = "value"
section = config.add_section('my_section') section = config.add_section("my_section")
section.add_option(StringOption, opt_name, default=opt_default_value) section.add_option(StringOption, opt_name, default=opt_default_value)
config.parse_arguments_options(argv=[], create=False) config.parse_arguments_options(argv=[], create=False)
@ -219,8 +217,8 @@ def test_get_default():
def test_logging_splited_stdout_stderr(capsys): def test_logging_splited_stdout_stderr(capsys):
config = Config('Test app') config = Config("Test app")
config.parse_arguments_options(argv=['-C', '-v'], create=False) config.parse_arguments_options(argv=["-C", "-v"], create=False)
info_msg = "[info]" info_msg = "[info]"
err_msg = "[error]" err_msg = "[error]"
logging.getLogger().info(info_msg) logging.getLogger().info(info_msg)
@ -239,9 +237,9 @@ def test_logging_splited_stdout_stderr(capsys):
@pytest.fixture() @pytest.fixture()
def config_with_file(tmpdir): def config_with_file(tmpdir):
config = Config('Test app') config = Config("Test app")
config_dir = tmpdir.mkdir('config') config_dir = tmpdir.mkdir("config")
config_file = config_dir.join('config.ini') config_file = config_dir.join("config.ini")
config.save(os.path.join(config_file.dirname, config_file.basename)) config.save(os.path.join(config_file.dirname, config_file.basename))
return config return config
@ -250,6 +248,7 @@ def generate_mock_input(expected_prompt, input_value):
def mock_input(self, prompt): # pylint: disable=unused-argument def mock_input(self, prompt): # pylint: disable=unused-argument
assert prompt == expected_prompt assert prompt == expected_prompt
return input_value return input_value
return mock_input return mock_input
@ -257,10 +256,9 @@ def generate_mock_input(expected_prompt, input_value):
def test_boolean_option_from_config(config_with_file): def test_boolean_option_from_config(config_with_file):
section = config_with_file.add_section('test') section = config_with_file.add_section("test")
default = True default = True
option = section.add_option( option = section.add_option(BooleanOption, "test_bool", default=default)
BooleanOption, 'test_bool', default=default)
config_with_file.save() config_with_file.save()
option.set(not default) option.set(not default)
@ -273,74 +271,76 @@ def test_boolean_option_from_config(config_with_file):
def test_boolean_option_ask_value(mocker): def test_boolean_option_ask_value(mocker):
config = Config('Test app') config = Config("Test app")
section = config.add_section('test') section = config.add_section("test")
name = 'test_bool' name = "test_bool"
option = section.add_option( option = section.add_option(BooleanOption, name, default=True)
BooleanOption, name, default=True)
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "y")
generate_mock_input(f'{name}: [Y/n] ', 'y')
) )
assert option.ask_value(set_it=False) is True assert option.ask_value(set_it=False) is True
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "Y")
generate_mock_input(f'{name}: [Y/n] ', 'Y')
) )
assert option.ask_value(set_it=False) is True assert option.ask_value(set_it=False) is True
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "")
generate_mock_input(f'{name}: [Y/n] ', '')
) )
assert option.ask_value(set_it=False) is True assert option.ask_value(set_it=False) is True
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "n")
generate_mock_input(f'{name}: [Y/n] ', 'n')
) )
assert option.ask_value(set_it=False) is False assert option.ask_value(set_it=False) is False
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "N")
generate_mock_input(f'{name}: [Y/n] ', 'N')
) )
assert option.ask_value(set_it=False) is False assert option.ask_value(set_it=False) is False
def test_boolean_option_to_config(): def test_boolean_option_to_config():
config = Config('Test app') config = Config("Test app")
section = config.add_section('test') section = config.add_section("test")
default = True default = True
option = section.add_option(BooleanOption, 'test_bool', default=default) option = section.add_option(BooleanOption, "test_bool", default=default)
assert option.to_config(True) == 'true' assert option.to_config(True) == "true"
assert option.to_config(False) == 'false' assert option.to_config(False) == "false"
def test_boolean_option_export_to_config(config_with_file): def test_boolean_option_export_to_config(config_with_file):
section = config_with_file.add_section('test') section = config_with_file.add_section("test")
name = 'test_bool' name = "test_bool"
comment = 'Test boolean' comment = "Test boolean"
default = True default = True
option = section.add_option( option = section.add_option(BooleanOption, name, default=default, comment=comment)
BooleanOption, name, default=default, comment=comment)
assert option.export_to_config() == f"""# {comment} assert (
option.export_to_config()
== f"""# {comment}
# Default: {str(default).lower()} # Default: {str(default).lower()}
# {name} = # {name} =
""" """
)
option.set(not default) option.set(not default)
assert option.export_to_config() == f"""# {comment} assert (
option.export_to_config()
== f"""# {comment}
# Default: {str(default).lower()} # Default: {str(default).lower()}
{name} = {str(not default).lower()} {name} = {str(not default).lower()}
""" """
)
option.set(default) option.set(default)
assert option.export_to_config() == f"""# {comment} assert (
option.export_to_config()
== f"""# {comment}
# Default: {str(default).lower()} # Default: {str(default).lower()}
# {name} = # {name} =
""" """
)
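The reflowed `export_to_config()` assertions above show the wrapping style black applies at `--line-length 100`: when a comparison no longer fits on one line, the whole expression is put inside parentheses and split across lines rather than using backslash continuations. A minimal, runnable sketch of that shape (the variable names are illustrative, not taken from the library's API):

```python
# What black produces for an over-long assert: the comparison is parenthesized
# and each operand gets its own line, keeping every source line under 100 chars.
name = "test_bool"
comment = "Test boolean"
default = True

exported = f"""# {comment}
# Default: {str(default).lower()}
# {name} =
"""

assert (
    exported
    == f"""# {comment}
# Default: {str(default).lower()}
# {name} =
"""
)
```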

View file

@ -2,7 +2,6 @@
""" Tests on opening hours helpers """ """ Tests on opening hours helpers """
import pytest import pytest
from MySQLdb._exceptions import Error from MySQLdb._exceptions import Error
from mylib.mysql import MyDB from mylib.mysql import MyDB
@ -11,7 +10,9 @@ from mylib.mysql import MyDB
class FakeMySQLdbCursor: class FakeMySQLdbCursor:
"""Fake MySQLdb cursor""" """Fake MySQLdb cursor"""
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
self.expected_sql = expected_sql self.expected_sql = expected_sql
self.expected_params = expected_params self.expected_params = expected_params
self.expected_return = expected_return self.expected_return = expected_return
@ -20,13 +21,25 @@ class FakeMySQLdbCursor:
def execute(self, sql, params=None): def execute(self, sql, params=None):
if self.expected_exception: if self.expected_exception:
raise Error(f'{self}.execute({sql}, {params}): expected exception') raise Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith('select '): if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
return self.expected_return return self.expected_return
@property @property
@ -39,16 +52,14 @@ class FakeMySQLdbCursor:
def fetchall(self): def fetchall(self):
if isinstance(self.expected_return, list): if isinstance(self.expected_return, list):
return ( return (
list(row.values()) list(row.values()) if isinstance(row, dict) else row for row in self.expected_return
if isinstance(row, dict) else row
for row in self.expected_return
) )
return self.expected_return return self.expected_return
def __repr__(self): def __repr__(self):
return ( return (
f'FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, ' f"FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, "
f'{self.expected_return}, {self.expected_just_try})' f"{self.expected_return}, {self.expected_just_try})"
) )
@ -63,11 +74,14 @@ class FakeMySQLdb:
just_try = False just_try = False
def __init__(self, **kwargs): def __init__(self, **kwargs):
allowed_kwargs = dict(db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool) allowed_kwargs = dict(
db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool
)
for arg, value in kwargs.items(): for arg, value in kwargs.items():
assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"' assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"'
assert isinstance(value, allowed_kwargs[arg]), \ assert isinstance(
f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})' value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
setattr(self, arg, value) setattr(self, arg, value)
def close(self): def close(self):
@ -75,9 +89,11 @@ class FakeMySQLdb:
def cursor(self): def cursor(self):
return FakeMySQLdbCursor( return FakeMySQLdbCursor(
self.expected_sql, self.expected_params, self.expected_sql,
self.expected_return, self.expected_just_try or self.just_try, self.expected_params,
self.expected_exception self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
) )
def commit(self): def commit(self):
@ -105,19 +121,19 @@ def fake_mysqldb_connect_just_try(**kwargs):
@pytest.fixture @pytest.fixture
def test_mydb(): def test_mydb():
return MyDB('127.0.0.1', 'user', 'password', 'dbname') return MyDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_mydb(mocker): def fake_mydb(mocker):
mocker.patch('MySQLdb.connect', fake_mysqldb_connect) mocker.patch("MySQLdb.connect", fake_mysqldb_connect)
return MyDB('127.0.0.1', 'user', 'password', 'dbname') return MyDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_just_try_mydb(mocker): def fake_just_try_mydb(mocker):
mocker.patch('MySQLdb.connect', fake_mysqldb_connect_just_try) mocker.patch("MySQLdb.connect", fake_mysqldb_connect_just_try)
return MyDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True) return MyDB("127.0.0.1", "user", "password", "dbname", just_try=True)
@pytest.fixture @pytest.fixture
@ -132,13 +148,22 @@ def fake_connected_just_try_mydb(fake_just_try_mydb):
return fake_just_try_mydb return fake_just_try_mydb
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs): def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
return expected_return return expected_return
return mock_args return mock_args
@ -146,13 +171,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode" assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
return expected_return return expected_return
return mock_doSQL return mock_doSQL
@ -166,15 +200,11 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter(): def test_combine_params_with_to_add_parameter():
assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict( assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs(): def test_combine_params_with_kargs():
assert MyDB._combine_params(dict(test1=1), test2=2) == dict( assert MyDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter(): def test_combine_params_with_kargs_and_to_add_parameter():
@ -184,47 +214,40 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved(): def test_format_where_clauses_params_are_preserved():
args = ('test = test', dict(test1=1)) args = ("test = test", dict(test1=1))
assert MyDB._format_where_clauses(*args) == args assert MyDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw(): def test_format_where_clauses_raw():
assert MyDB._format_where_clauses('test = test') == (('test = test'), {}) assert MyDB._format_where_clauses("test = test") == ("test = test", {})
def test_format_where_clauses_tuple_clause_with_params(): def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ( where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2))
'test1 = %(test1)s AND test2 = %(test2)s',
dict(test1=1, test2=2)
)
assert MyDB._format_where_clauses(where_clauses) == where_clauses assert MyDB._format_where_clauses(where_clauses) == where_clauses
def test_format_where_clauses_dict(): def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert MyDB._format_where_clauses(where_clauses) == ( assert MyDB._format_where_clauses(where_clauses) == (
'`test1` = %(test1)s AND `test2` = %(test2)s', "`test1` = %(test1)s AND `test2` = %(test2)s",
where_clauses where_clauses,
) )
def test_format_where_clauses_combined_types(): def test_format_where_clauses_combined_types():
where_clauses = ( where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4))
'test1 = 1',
('test2 LIKE %(test2)s', dict(test2=2)),
dict(test3=3, test4=4)
)
assert MyDB._format_where_clauses(where_clauses) == ( assert MyDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s', "test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s",
dict(test2=2, test3=3, test4=4) dict(test2=2, test3=3, test4=4),
) )
def test_format_where_clauses_with_where_op(): def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert MyDB._format_where_clauses(where_clauses, where_op='OR') == ( assert MyDB._format_where_clauses(where_clauses, where_op="OR") == (
'`test1` = %(test1)s OR `test2` = %(test2)s', "`test1` = %(test1)s OR `test2` = %(test2)s",
where_clauses where_clauses,
) )
@ -232,8 +255,8 @@ def test_add_where_clauses():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert MyDB._add_where_clauses(sql, None, where_clauses) == ( assert MyDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s', sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s",
where_clauses where_clauses,
) )
@ -242,106 +265,102 @@ def test_add_where_clauses_preserved_params():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
params = dict(fake1=1) params = dict(fake1=1)
assert MyDB._add_where_clauses(sql, params.copy(), where_clauses) == ( assert MyDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s', sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s",
dict(**where_clauses, **params) dict(**where_clauses, **params),
) )
def test_add_where_clauses_with_op(): def test_add_where_clauses_with_op():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = ('test1=1', 'test2=2') where_clauses = ("test1=1", "test2=2")
assert MyDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( assert MyDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + ' WHERE test1=1 OR test2=2', sql + " WHERE test1=1 OR test2=2",
{} {},
) )
def test_add_where_clauses_with_duplicated_field(): def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=%(test1)s" sql = "UPDATE table SET test1=%(test1)s"
params = dict(test1='new_value') params = dict(test1="new_value")
where_clauses = dict(test1='where_value') where_clauses = dict(test1="where_value")
assert MyDB._add_where_clauses(sql, params, where_clauses) == ( assert MyDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE `test1` = %(test1_1)s', sql + " WHERE `test1` = %(test1_1)s",
dict(test1='new_value', test1_1='where_value') dict(test1="new_value", test1_1="where_value"),
) )
def test_quote_table_name(): def test_quote_table_name():
assert MyDB._quote_table_name("mytable") == '`mytable`' assert MyDB._quote_table_name("mytable") == "`mytable`"
assert MyDB._quote_table_name("myschema.mytable") == '`myschema`.`mytable`' assert MyDB._quote_table_name("myschema.mytable") == "`myschema`.`mytable`"
def test_insert(mocker, test_mydb): def test_insert(mocker, test_mydb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSQL', "mylib.mysql.MyDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)', "INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)", values
values ),
)
) )
assert test_mydb.insert('mytable', values) assert test_mydb.insert("mytable", values)
def test_insert_just_try(mocker, test_mydb): def test_insert_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.insert('mytable', dict(test1=1, test2=2), just_try=True) assert test_mydb.insert("mytable", dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_mydb): def test_update(mocker, test_mydb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSQL', "mylib.mysql.MyDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` = %(test3)s AND `test4` = %(test4)s', "UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` ="
dict(**values, **where_clauses) " %(test3)s AND `test4` = %(test4)s",
) dict(**values, **where_clauses),
),
) )
assert test_mydb.update('mytable', values, where_clauses) assert test_mydb.update("mytable", values, where_clauses)
def test_update_just_try(mocker, test_mydb): def test_update_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.update('mytable', dict(test1=1, test2=2), None, just_try=True) assert test_mydb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_mydb): def test_delete(mocker, test_mydb):
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSQL', "mylib.mysql.MyDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s', "DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s", where_clauses
where_clauses ),
)
) )
assert test_mydb.delete('mytable', where_clauses) assert test_mydb.delete("mytable", where_clauses)
def test_delete_just_try(mocker, test_mydb): def test_delete_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.delete('mytable', None, just_try=True) assert test_mydb.delete("mytable", None, just_try=True)
def test_truncate(mocker, test_mydb): def test_truncate(mocker, test_mydb):
mocker.patch( mocker.patch("mylib.mysql.MyDB.doSQL", generate_mock_doSQL("TRUNCATE TABLE `mytable`", None))
'mylib.mysql.MyDB.doSQL',
generate_mock_doSQL('TRUNCATE TABLE `mytable`', None)
)
assert test_mydb.truncate('mytable') assert test_mydb.truncate("mytable")
def test_truncate_just_try(mocker, test_mydb): def test_truncate_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try)
assert test_mydb.truncate('mytable', just_try=True) assert test_mydb.truncate("mytable", just_try=True)
def test_select(mocker, test_mydb): def test_select(mocker, test_mydb):
fields = ('field1', 'field2') fields = ("field1", "field2")
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
expected_return = [ expected_return = [
dict(field1=1, field2=2), dict(field1=1, field2=2),
@ -349,30 +368,28 @@ def test_select(mocker, test_mydb):
] ]
order_by = "field1, DESC" order_by = "field1, DESC"
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSelect', "mylib.mysql.MyDB.doSelect",
generate_mock_doSQL( generate_mock_doSQL(
'SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` = %(test4)s ORDER BY ' + order_by, "SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` ="
where_clauses, expected_return " %(test4)s ORDER BY " + order_by,
) where_clauses,
expected_return,
),
) )
assert test_mydb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return assert test_mydb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
def test_select_without_field_and_order_by(mocker, test_mydb): def test_select_without_field_and_order_by(mocker, test_mydb):
mocker.patch( mocker.patch("mylib.mysql.MyDB.doSelect", generate_mock_doSQL("SELECT * FROM `mytable`"))
'mylib.mysql.MyDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM `mytable`'
)
)
assert test_mydb.select('mytable') assert test_mydb.select("mytable")
def test_select_just_try(mocker, test_mydb): def test_select_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try)
assert test_mydb.select('mytable', None, None, just_try=True) assert test_mydb.select("mytable", None, None, just_try=True)
# #
# Tests on main methods # Tests on main methods
@ -389,12 +406,7 @@ def test_connect(mocker, test_mydb):
use_unicode=True, use_unicode=True,
) )
mocker.patch( mocker.patch("MySQLdb.connect", generate_mock_args(expected_kwargs=expected_kwargs))
'MySQLdb.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_mydb.connect() assert test_mydb.connect()
@ -408,48 +420,61 @@ def test_close_connected(fake_connected_mydb):
def test_doSQL(fake_connected_mydb): def test_doSQL(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s' fake_connected_mydb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s"
fake_connected_mydb._conn.expected_params = dict(test1=1) fake_connected_mydb._conn.expected_params = dict(test1=1)
fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params) fake_connected_mydb.doSQL(
fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params
)
def test_doSQL_without_params(fake_connected_mydb): def test_doSQL_without_params(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'DELETE FROM table' fake_connected_mydb._conn.expected_sql = "DELETE FROM table"
fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql) fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_mydb): def test_doSQL_just_try(fake_connected_just_try_mydb):
assert fake_connected_just_try_mydb.doSQL('DELETE FROM table') assert fake_connected_just_try_mydb.doSQL("DELETE FROM table")
def test_doSQL_on_exception(fake_connected_mydb): def test_doSQL_on_exception(fake_connected_mydb):
fake_connected_mydb._conn.expected_exception = True fake_connected_mydb._conn.expected_exception = True
assert fake_connected_mydb.doSQL('DELETE FROM table') is False assert fake_connected_mydb.doSQL("DELETE FROM table") is False
def test_doSelect(fake_connected_mydb): def test_doSelect(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_mydb._conn.expected_params = dict(test1=1) fake_connected_mydb._conn.expected_params = dict(test1=1)
fake_connected_mydb._conn.expected_return = [dict(test1=1)] fake_connected_mydb._conn.expected_return = [dict(test1=1)]
assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params) == fake_connected_mydb._conn.expected_return assert (
fake_connected_mydb.doSelect(
fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params
)
== fake_connected_mydb._conn.expected_return
)
def test_doSelect_without_params(fake_connected_mydb): def test_doSelect_without_params(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table' fake_connected_mydb._conn.expected_sql = "SELECT * FROM table"
fake_connected_mydb._conn.expected_return = [dict(test1=1)] fake_connected_mydb._conn.expected_return = [dict(test1=1)]
assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql) == fake_connected_mydb._conn.expected_return assert (
fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql)
== fake_connected_mydb._conn.expected_return
)
def test_doSelect_on_exception(fake_connected_mydb): def test_doSelect_on_exception(fake_connected_mydb):
fake_connected_mydb._conn.expected_exception = True fake_connected_mydb._conn.expected_exception = True
assert fake_connected_mydb.doSelect('SELECT * FROM table') is False assert fake_connected_mydb.doSelect("SELECT * FROM table") is False
def test_doSelect_just_try(fake_connected_just_try_mydb): def test_doSelect_just_try(fake_connected_just_try_mydb):
fake_connected_just_try_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_just_try_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_just_try_mydb._conn.expected_params = dict(test1=1) fake_connected_just_try_mydb._conn.expected_params = dict(test1=1)
fake_connected_just_try_mydb._conn.expected_return = [dict(test1=1)] fake_connected_just_try_mydb._conn.expected_return = [dict(test1=1)]
assert fake_connected_just_try_mydb.doSelect( assert (
fake_connected_just_try_mydb.doSelect(
fake_connected_just_try_mydb._conn.expected_sql, fake_connected_just_try_mydb._conn.expected_sql,
fake_connected_just_try_mydb._conn.expected_params fake_connected_just_try_mydb._conn.expected_params,
) == fake_connected_just_try_mydb._conn.expected_return )
== fake_connected_just_try_mydb._conn.expected_return
)
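The long expected SQL strings above (for example in `test_update` and `test_select`) are split using Python's implicit concatenation of adjacent string literals, so each source line stays under the 100-character limit while the expected query remains a single string. A small self-contained sketch of that idiom, with an illustrative query that only mimics the ones in these tests:

```python
# Adjacent string literals inside parentheses are joined at compile time,
# so the split point is purely cosmetic and does not change the value.
expected_sql = (
    "UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` ="
    " %(test3)s AND `test4` = %(test4)s"
)

single_string = (
    "UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s"
    " WHERE `test3` = %(test3)s AND `test4` = %(test4)s"
)

# Both spellings yield the exact same SQL string.
assert expected_sql == single_string
```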

View file

@ -2,6 +2,7 @@
""" Tests on opening hours helpers """ """ Tests on opening hours helpers """
import datetime import datetime
import pytest import pytest
from mylib import opening_hours from mylib import opening_hours
@ -12,14 +13,16 @@ from mylib import opening_hours
def test_parse_exceptional_closures_one_day_without_time_period(): def test_parse_exceptional_closures_one_day_without_time_period():
assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [{'days': [datetime.date(2017, 9, 22)], 'hours_periods': []}] assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [
{"days": [datetime.date(2017, 9, 22)], "hours_periods": []}
]
def test_parse_exceptional_closures_one_day_with_time_period(): def test_parse_exceptional_closures_one_day_with_time_period():
assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30"]) == [ assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30"]) == [
{ {
'days': [datetime.date(2017, 11, 26)], "days": [datetime.date(2017, 11, 26)],
'hours_periods': [{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}] "hours_periods": [{"start": datetime.time(9, 30), "stop": datetime.time(12, 30)}],
} }
] ]
@ -27,11 +30,11 @@ def test_parse_exceptional_closures_one_day_with_time_period():
def test_parse_exceptional_closures_one_day_with_multiple_time_periods(): def test_parse_exceptional_closures_one_day_with_multiple_time_periods():
assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30 14h-18h"]) == [ assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30 14h-18h"]) == [
{ {
'days': [datetime.date(2017, 11, 26)], "days": [datetime.date(2017, 11, 26)],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
] ],
} }
] ]
@ -39,8 +42,12 @@ def test_parse_exceptional_closures_one_day_with_multiple_time_periods():
def test_parse_exceptional_closures_full_days_period(): def test_parse_exceptional_closures_full_days_period():
assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017"]) == [ assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017"]) == [
{ {
'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)], "days": [
'hours_periods': [] datetime.date(2017, 9, 20),
datetime.date(2017, 9, 21),
datetime.date(2017, 9, 22),
],
"hours_periods": [],
} }
] ]
@ -53,8 +60,12 @@ def test_parse_exceptional_closures_invalid_days_period():
def test_parse_exceptional_closures_days_period_with_time_period(): def test_parse_exceptional_closures_days_period_with_time_period():
assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017 9h-12h"]) == [ assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017 9h-12h"]) == [
{ {
'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)], "days": [
'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}] datetime.date(2017, 9, 20),
datetime.date(2017, 9, 21),
datetime.date(2017, 9, 22),
],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}],
} }
] ]
@ -70,31 +81,38 @@ def test_parse_exceptional_closures_invalid_time_period():
def test_parse_exceptional_closures_multiple_periods(): def test_parse_exceptional_closures_multiple_periods():
assert opening_hours.parse_exceptional_closures(["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"]) == [ assert opening_hours.parse_exceptional_closures(
["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"]
) == [
{ {
'days': [ "days": [
datetime.date(2017, 9, 20), datetime.date(2017, 9, 20),
datetime.date(2017, 11, 25), datetime.date(2017, 11, 25),
datetime.date(2017, 11, 26), datetime.date(2017, 11, 26),
], ],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
] ],
} }
] ]
# #
# Tests on parse_normal_opening_hours() # Tests on parse_normal_opening_hours()
# #
def test_parse_normal_opening_hours_one_day(): def test_parse_normal_opening_hours_one_day():
assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [{'days': ["jeudi"], 'hours_periods': []}] assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [
{"days": ["jeudi"], "hours_periods": []}
]
def test_parse_normal_opening_hours_multiple_days(): def test_parse_normal_opening_hours_multiple_days():
assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [{'days': ["lundi", "jeudi"], 'hours_periods': []}] assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [
{"days": ["lundi", "jeudi"], "hours_periods": []}
]
def test_parse_normal_opening_hours_invalid_day(): def test_parse_normal_opening_hours_invalid_day():
@ -104,13 +122,17 @@ def test_parse_normal_opening_hours_invalid_day():
def test_parse_normal_opening_hours_one_days_period(): def test_parse_normal_opening_hours_one_days_period():
assert opening_hours.parse_normal_opening_hours(["lundi-jeudi"]) == [ assert opening_hours.parse_normal_opening_hours(["lundi-jeudi"]) == [
{'days': ["lundi", "mardi", "mercredi", "jeudi"], 'hours_periods': []} {"days": ["lundi", "mardi", "mercredi", "jeudi"], "hours_periods": []}
] ]
def test_parse_normal_opening_hours_one_day_with_one_time_period(): def test_parse_normal_opening_hours_one_day_with_one_time_period():
assert opening_hours.parse_normal_opening_hours(["jeudi 9h-12h"]) == [ assert opening_hours.parse_normal_opening_hours(["jeudi 9h-12h"]) == [
{'days': ["jeudi"], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}]}] {
"days": ["jeudi"],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}],
}
]
def test_parse_normal_opening_hours_invalid_days_period(): def test_parse_normal_opening_hours_invalid_days_period():
@ -122,7 +144,10 @@ def test_parse_normal_opening_hours_invalid_days_period():
def test_parse_normal_opening_hours_one_time_period(): def test_parse_normal_opening_hours_one_time_period():
assert opening_hours.parse_normal_opening_hours(["9h-18h30"]) == [ assert opening_hours.parse_normal_opening_hours(["9h-18h30"]) == [
{'days': [], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(18, 30)}]} {
"days": [],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(18, 30)}],
}
] ]
@ -132,48 +157,60 @@ def test_parse_normal_opening_hours_invalid_time_period():
def test_parse_normal_opening_hours_multiple_periods(): def test_parse_normal_opening_hours_multiple_periods():
assert opening_hours.parse_normal_opening_hours(["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"]) == [ assert opening_hours.parse_normal_opening_hours(
["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"]
) == [
{ {
'days': ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi'], "days": ["lundi", "mardi", "mercredi", "jeudi", "vendredi"],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
] ],
}, },
{ {
'days': ['samedi'], "days": ["samedi"],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(18, 0)}, {"start": datetime.time(9, 30), "stop": datetime.time(18, 0)},
] ],
}, },
{ {
'days': ['dimanche'], "days": ["dimanche"],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 0)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 0)},
] ],
}, },
] ]
# #
# Tests on is_closed # Tests on is_closed
# #
exceptional_closures = ["22/09/2017", "20/09/2017-22/09/2017", "20/09/2017-22/09/2017 18/09/2017", "25/11/2017", "26/11/2017 9h30-12h30"] exceptional_closures = [
normal_opening_hours = ["lundi-mardi jeudi 9h30-12h30 14h-16h30", "mercredi vendredi 9h30-12h30 14h-17h"] "22/09/2017",
"20/09/2017-22/09/2017",
"20/09/2017-22/09/2017 18/09/2017",
"25/11/2017",
"26/11/2017 9h30-12h30",
]
normal_opening_hours = [
"lundi-mardi jeudi 9h30-12h30 14h-16h30",
"mercredi vendredi 9h30-12h30 14h-17h",
]
nonworking_public_holidays = [ nonworking_public_holidays = [
'1janvier', "1janvier",
'paques', "paques",
'lundi_paques', "lundi_paques",
'1mai', "1mai",
'8mai', "8mai",
'jeudi_ascension', "jeudi_ascension",
'lundi_pentecote', "lundi_pentecote",
'14juillet', "14juillet",
'15aout', "15aout",
'1novembre', "1novembre",
'11novembre', "11novembre",
'noel', "noel",
] ]
@ -182,12 +219,8 @@ def test_is_closed_when_normaly_closed_by_hour():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 1, 20, 15) when=datetime.datetime(2017, 5, 1, 20, 15),
) == { ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_on_exceptional_closure_full_day(): def test_is_closed_on_exceptional_closure_full_day():
@ -195,12 +228,8 @@ def test_is_closed_on_exceptional_closure_full_day():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 9, 22, 14, 15) when=datetime.datetime(2017, 9, 22, 14, 15),
) == { ) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": True}
'closed': True,
'exceptional_closure': True,
'exceptional_closure_all_day': True
}
def test_is_closed_on_exceptional_closure_day(): def test_is_closed_on_exceptional_closure_day():
@ -208,12 +237,8 @@ def test_is_closed_on_exceptional_closure_day():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 11, 26, 10, 30) when=datetime.datetime(2017, 11, 26, 10, 30),
) == { ) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': True,
'exceptional_closure_all_day': False
}
def test_is_closed_on_nonworking_public_holidays(): def test_is_closed_on_nonworking_public_holidays():
@ -221,12 +246,8 @@ def test_is_closed_on_nonworking_public_holidays():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 1, 1, 10, 30) when=datetime.datetime(2017, 1, 1, 10, 30),
) == { ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_when_normaly_closed_by_day(): def test_is_closed_when_normaly_closed_by_day():
@ -234,12 +255,8 @@ def test_is_closed_when_normaly_closed_by_day():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 6, 14, 15) when=datetime.datetime(2017, 5, 6, 14, 15),
) == { ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_when_normaly_opened(): def test_is_closed_when_normaly_opened():
@ -247,12 +264,8 @@ def test_is_closed_when_normaly_opened():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 2, 15, 15) when=datetime.datetime(2017, 5, 2, 15, 15),
) == { ) == {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': False,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_easter_date(): def test_easter_date():
@ -272,18 +285,18 @@ def test_easter_date():
def test_nonworking_french_public_days_of_the_year(): def test_nonworking_french_public_days_of_the_year():
assert opening_hours.nonworking_french_public_days_of_the_year(2021) == { assert opening_hours.nonworking_french_public_days_of_the_year(2021) == {
'1janvier': datetime.date(2021, 1, 1), "1janvier": datetime.date(2021, 1, 1),
'paques': datetime.date(2021, 4, 4), "paques": datetime.date(2021, 4, 4),
'lundi_paques': datetime.date(2021, 4, 5), "lundi_paques": datetime.date(2021, 4, 5),
'1mai': datetime.date(2021, 5, 1), "1mai": datetime.date(2021, 5, 1),
'8mai': datetime.date(2021, 5, 8), "8mai": datetime.date(2021, 5, 8),
'jeudi_ascension': datetime.date(2021, 5, 13), "jeudi_ascension": datetime.date(2021, 5, 13),
'pentecote': datetime.date(2021, 5, 23), "pentecote": datetime.date(2021, 5, 23),
'lundi_pentecote': datetime.date(2021, 5, 24), "lundi_pentecote": datetime.date(2021, 5, 24),
'14juillet': datetime.date(2021, 7, 14), "14juillet": datetime.date(2021, 7, 14),
'15aout': datetime.date(2021, 8, 15), "15aout": datetime.date(2021, 8, 15),
'1novembre': datetime.date(2021, 11, 1), "1novembre": datetime.date(2021, 11, 1),
'11novembre': datetime.date(2021, 11, 11), "11novembre": datetime.date(2021, 11, 11),
'noel': datetime.date(2021, 12, 25), "noel": datetime.date(2021, 12, 25),
'saint_etienne': datetime.date(2021, 12, 26) "saint_etienne": datetime.date(2021, 12, 26),
} }
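The exploded `exceptional_closures` and `nonworking_public_holidays` literals earlier in this file also illustrate black's magic trailing comma: once a collection literal carries a trailing comma, black keeps one element per line even if the whole literal would fit within the line length. A minimal sketch (the values are only examples):

```python
# No trailing comma: black collapses the literal onto one line when it fits.
weekdays = ["lundi", "mardi", "mercredi"]

# Trailing comma after the last element: black preserves the exploded,
# one-element-per-line layout.
holidays = [
    "1janvier",
    "paques",
    "noel",
]

assert len(weekdays) == 3 and len(holidays) == 3
```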

View file

@ -10,7 +10,9 @@ from mylib.oracle import OracleDB
class FakeCXOracleCursor: class FakeCXOracleCursor:
"""Fake cx_Oracle cursor""" """Fake cx_Oracle cursor"""
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
self.expected_sql = expected_sql self.expected_sql = expected_sql
self.expected_params = expected_params self.expected_params = expected_params
self.expected_return = expected_return self.expected_return = expected_return
@ -21,13 +23,25 @@ class FakeCXOracleCursor:
def execute(self, sql, **params): def execute(self, sql, **params):
assert self.opened assert self.opened
if self.expected_exception: if self.expected_exception:
raise cx_Oracle.Error(f'{self}.execute({sql}, {params}): expected exception') raise cx_Oracle.Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith('select '): if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
return self.expected_return return self.expected_return
def fetchall(self): def fetchall(self):
@ -43,8 +57,8 @@ class FakeCXOracleCursor:
def __repr__(self): def __repr__(self):
return ( return (
f'FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, ' f"FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, "
f'{self.expected_return}, {self.expected_just_try})' f"{self.expected_return}, {self.expected_just_try})"
) )
@ -62,7 +76,9 @@ class FakeCXOracle:
allowed_kwargs = dict(dsn=str, user=str, password=(str, None)) allowed_kwargs = dict(dsn=str, user=str, password=(str, None))
for arg, value in kwargs.items(): for arg, value in kwargs.items():
assert arg in allowed_kwargs, f"Invalid arg {arg}='{value}'" assert arg in allowed_kwargs, f"Invalid arg {arg}='{value}'"
assert isinstance(value, allowed_kwargs[arg]), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})" assert isinstance(
value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
setattr(self, arg, value) setattr(self, arg, value)
def close(self): def close(self):
@ -70,9 +86,11 @@ class FakeCXOracle:
def cursor(self): def cursor(self):
return FakeCXOracleCursor( return FakeCXOracleCursor(
self.expected_sql, self.expected_params, self.expected_sql,
self.expected_return, self.expected_just_try or self.just_try, self.expected_params,
self.expected_exception self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
) )
def commit(self): def commit(self):
@ -100,19 +118,19 @@ def fake_cxoracle_connect_just_try(**kwargs):
@pytest.fixture @pytest.fixture
def test_oracledb(): def test_oracledb():
return OracleDB('127.0.0.1/dbname', 'user', 'password') return OracleDB("127.0.0.1/dbname", "user", "password")
@pytest.fixture @pytest.fixture
def fake_oracledb(mocker): def fake_oracledb(mocker):
mocker.patch('cx_Oracle.connect', fake_cxoracle_connect) mocker.patch("cx_Oracle.connect", fake_cxoracle_connect)
return OracleDB('127.0.0.1/dbname', 'user', 'password') return OracleDB("127.0.0.1/dbname", "user", "password")
@pytest.fixture @pytest.fixture
def fake_just_try_oracledb(mocker): def fake_just_try_oracledb(mocker):
mocker.patch('cx_Oracle.connect', fake_cxoracle_connect_just_try) mocker.patch("cx_Oracle.connect", fake_cxoracle_connect_just_try)
return OracleDB('127.0.0.1/dbname', 'user', 'password', just_try=True) return OracleDB("127.0.0.1/dbname", "user", "password", just_try=True)
@pytest.fixture @pytest.fixture
@ -127,13 +145,22 @@ def fake_connected_just_try_oracledb(fake_just_try_oracledb):
return fake_just_try_oracledb return fake_just_try_oracledb
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs): def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
return expected_return return expected_return
return mock_args return mock_args
@ -141,13 +168,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode" assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
return expected_return return expected_return
return mock_doSQL return mock_doSQL
@@ -161,15 +197,11 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter(): def test_combine_params_with_to_add_parameter():
assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict( assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs(): def test_combine_params_with_kargs():
assert OracleDB._combine_params(dict(test1=1), test2=2) == dict( assert OracleDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter(): def test_combine_params_with_kargs_and_to_add_parameter():
@@ -179,19 +211,16 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved(): def test_format_where_clauses_params_are_preserved():
args = ('test = test', dict(test1=1)) args = ("test = test", dict(test1=1))
assert OracleDB._format_where_clauses(*args) == args assert OracleDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw(): def test_format_where_clauses_raw():
assert OracleDB._format_where_clauses('test = test') == (('test = test'), {}) assert OracleDB._format_where_clauses("test = test") == ("test = test", {})
def test_format_where_clauses_tuple_clause_with_params(): def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ( where_clauses = ("test1 = :test1 AND test2 = :test2", dict(test1=1, test2=2))
'test1 = :test1 AND test2 = :test2',
dict(test1=1, test2=2)
)
assert OracleDB._format_where_clauses(where_clauses) == where_clauses assert OracleDB._format_where_clauses(where_clauses) == where_clauses
@@ -199,27 +228,23 @@ def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert OracleDB._format_where_clauses(where_clauses) == ( assert OracleDB._format_where_clauses(where_clauses) == (
'"test1" = :test1 AND "test2" = :test2', '"test1" = :test1 AND "test2" = :test2',
where_clauses where_clauses,
) )
def test_format_where_clauses_combined_types(): def test_format_where_clauses_combined_types():
where_clauses = ( where_clauses = ("test1 = 1", ("test2 LIKE :test2", dict(test2=2)), dict(test3=3, test4=4))
'test1 = 1',
('test2 LIKE :test2', dict(test2=2)),
dict(test3=3, test4=4)
)
assert OracleDB._format_where_clauses(where_clauses) == ( assert OracleDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE :test2 AND "test3" = :test3 AND "test4" = :test4', 'test1 = 1 AND test2 LIKE :test2 AND "test3" = :test3 AND "test4" = :test4',
dict(test2=2, test3=3, test4=4) dict(test2=2, test3=3, test4=4),
) )
def test_format_where_clauses_with_where_op(): def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert OracleDB._format_where_clauses(where_clauses, where_op='OR') == ( assert OracleDB._format_where_clauses(where_clauses, where_op="OR") == (
'"test1" = :test1 OR "test2" = :test2', '"test1" = :test1 OR "test2" = :test2',
where_clauses where_clauses,
) )
@@ -228,7 +253,7 @@ def test_add_where_clauses():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert OracleDB._add_where_clauses(sql, None, where_clauses) == ( assert OracleDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE "test1" = :test1 AND "test2" = :test2', sql + ' WHERE "test1" = :test1 AND "test2" = :test2',
where_clauses where_clauses,
) )
@@ -238,26 +263,26 @@ def test_add_where_clauses_preserved_params():
params = dict(fake1=1) params = dict(fake1=1)
assert OracleDB._add_where_clauses(sql, params.copy(), where_clauses) == ( assert OracleDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE "test1" = :test1 AND "test2" = :test2', sql + ' WHERE "test1" = :test1 AND "test2" = :test2',
dict(**where_clauses, **params) dict(**where_clauses, **params),
) )
def test_add_where_clauses_with_op(): def test_add_where_clauses_with_op():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = ('test1=1', 'test2=2') where_clauses = ("test1=1", "test2=2")
assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + ' WHERE test1=1 OR test2=2', sql + " WHERE test1=1 OR test2=2",
{} {},
) )
def test_add_where_clauses_with_duplicated_field(): def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=:test1" sql = "UPDATE table SET test1=:test1"
params = dict(test1='new_value') params = dict(test1="new_value")
where_clauses = dict(test1='where_value') where_clauses = dict(test1="where_value")
assert OracleDB._add_where_clauses(sql, params, where_clauses) == ( assert OracleDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE "test1" = :test1_1', sql + ' WHERE "test1" = :test1_1',
dict(test1='new_value', test1_1='where_value') dict(test1="new_value", test1_1="where_value"),
) )
@@ -269,74 +294,72 @@ def test_quote_table_name():
def test_insert(mocker, test_oracledb): def test_insert(mocker, test_oracledb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)', 'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)', values
values ),
)
) )
assert test_oracledb.insert('mytable', values) assert test_oracledb.insert("mytable", values)
def test_insert_just_try(mocker, test_oracledb): def test_insert_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.insert('mytable', dict(test1=1, test2=2), just_try=True) assert test_oracledb.insert("mytable", dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_oracledb): def test_update(mocker, test_oracledb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND "test4" = :test4', 'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND'
dict(**values, **where_clauses) ' "test4" = :test4',
) dict(**values, **where_clauses),
),
) )
assert test_oracledb.update('mytable', values, where_clauses) assert test_oracledb.update("mytable", values, where_clauses)
def test_update_just_try(mocker, test_oracledb): def test_update_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.update('mytable', dict(test1=1, test2=2), None, just_try=True) assert test_oracledb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_oracledb): def test_delete(mocker, test_oracledb):
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2', 'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2', where_clauses
where_clauses ),
)
) )
assert test_oracledb.delete('mytable', where_clauses) assert test_oracledb.delete("mytable", where_clauses)
def test_delete_just_try(mocker, test_oracledb): def test_delete_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.delete('mytable', None, just_try=True) assert test_oracledb.delete("mytable", None, just_try=True)
def test_truncate(mocker, test_oracledb): def test_truncate(mocker, test_oracledb):
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
) )
assert test_oracledb.truncate('mytable') assert test_oracledb.truncate("mytable")
def test_truncate_just_try(mocker, test_oracledb): def test_truncate_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try)
assert test_oracledb.truncate('mytable', just_try=True) assert test_oracledb.truncate("mytable", just_try=True)
def test_select(mocker, test_oracledb): def test_select(mocker, test_oracledb):
fields = ('field1', 'field2') fields = ("field1", "field2")
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
expected_return = [ expected_return = [
dict(field1=1, field2=2), dict(field1=1, field2=2),
@@ -344,30 +367,30 @@ def test_select(mocker, test_oracledb):
] ]
order_by = "field1, DESC" order_by = "field1, DESC"
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSelect', "mylib.oracle.OracleDB.doSelect",
generate_mock_doSQL( generate_mock_doSQL(
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4 ORDER BY ' + order_by, 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4'
where_clauses, expected_return " ORDER BY " + order_by,
) where_clauses,
expected_return,
),
) )
assert test_oracledb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return assert (
test_oracledb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
)
def test_select_without_field_and_order_by(mocker, test_oracledb): def test_select_without_field_and_order_by(mocker, test_oracledb):
mocker.patch( mocker.patch("mylib.oracle.OracleDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"'))
'mylib.oracle.OracleDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM "mytable"'
)
)
assert test_oracledb.select('mytable') assert test_oracledb.select("mytable")
def test_select_just_try(mocker, test_oracledb): def test_select_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try)
assert test_oracledb.select('mytable', None, None, just_try=True) assert test_oracledb.select("mytable", None, None, just_try=True)
# #
# Tests on main methods # Tests on main methods
@@ -376,17 +399,10 @@ def test_select_just_try(mocker, test_oracledb):
def test_connect(mocker, test_oracledb): def test_connect(mocker, test_oracledb):
expected_kwargs = dict( expected_kwargs = dict(
dsn=test_oracledb._dsn, dsn=test_oracledb._dsn, user=test_oracledb._user, password=test_oracledb._pwd
user=test_oracledb._user,
password=test_oracledb._pwd
) )
mocker.patch( mocker.patch("cx_Oracle.connect", generate_mock_args(expected_kwargs=expected_kwargs))
'cx_Oracle.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_oracledb.connect() assert test_oracledb.connect()
@@ -400,50 +416,62 @@ def test_close_connected(fake_connected_oracledb):
def test_doSQL(fake_connected_oracledb): def test_doSQL(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table WHERE test1 = :test1' fake_connected_oracledb._conn.expected_sql = "DELETE FROM table WHERE test1 = :test1"
fake_connected_oracledb._conn.expected_params = dict(test1=1) fake_connected_oracledb._conn.expected_params = dict(test1=1)
fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params) fake_connected_oracledb.doSQL(
fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params
)
def test_doSQL_without_params(fake_connected_oracledb): def test_doSQL_without_params(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table' fake_connected_oracledb._conn.expected_sql = "DELETE FROM table"
fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql) fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_oracledb): def test_doSQL_just_try(fake_connected_just_try_oracledb):
assert fake_connected_just_try_oracledb.doSQL('DELETE FROM table') assert fake_connected_just_try_oracledb.doSQL("DELETE FROM table")
def test_doSQL_on_exception(fake_connected_oracledb): def test_doSQL_on_exception(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_exception = True fake_connected_oracledb._conn.expected_exception = True
assert fake_connected_oracledb.doSQL('DELETE FROM table') is False assert fake_connected_oracledb.doSQL("DELETE FROM table") is False
def test_doSelect(fake_connected_oracledb): def test_doSelect(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1' fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1"
fake_connected_oracledb._conn.expected_params = dict(test1=1) fake_connected_oracledb._conn.expected_params = dict(test1=1)
fake_connected_oracledb._conn.expected_return = [dict(test1=1)] fake_connected_oracledb._conn.expected_return = [dict(test1=1)]
assert fake_connected_oracledb.doSelect( assert (
fake_connected_oracledb.doSelect(
fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_sql,
fake_connected_oracledb._conn.expected_params) == fake_connected_oracledb._conn.expected_return fake_connected_oracledb._conn.expected_params,
)
== fake_connected_oracledb._conn.expected_return
)
def test_doSelect_without_params(fake_connected_oracledb): def test_doSelect_without_params(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table' fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table"
fake_connected_oracledb._conn.expected_return = [dict(test1=1)] fake_connected_oracledb._conn.expected_return = [dict(test1=1)]
assert fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql) == fake_connected_oracledb._conn.expected_return assert (
fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql)
== fake_connected_oracledb._conn.expected_return
)
def test_doSelect_on_exception(fake_connected_oracledb): def test_doSelect_on_exception(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_exception = True fake_connected_oracledb._conn.expected_exception = True
assert fake_connected_oracledb.doSelect('SELECT * FROM table') is False assert fake_connected_oracledb.doSelect("SELECT * FROM table") is False
def test_doSelect_just_try(fake_connected_just_try_oracledb): def test_doSelect_just_try(fake_connected_just_try_oracledb):
fake_connected_just_try_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1' fake_connected_just_try_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1"
fake_connected_just_try_oracledb._conn.expected_params = dict(test1=1) fake_connected_just_try_oracledb._conn.expected_params = dict(test1=1)
fake_connected_just_try_oracledb._conn.expected_return = [dict(test1=1)] fake_connected_just_try_oracledb._conn.expected_return = [dict(test1=1)]
assert fake_connected_just_try_oracledb.doSelect( assert (
fake_connected_just_try_oracledb.doSelect(
fake_connected_just_try_oracledb._conn.expected_sql, fake_connected_just_try_oracledb._conn.expected_sql,
fake_connected_just_try_oracledb._conn.expected_params fake_connected_just_try_oracledb._conn.expected_params,
) == fake_connected_just_try_oracledb._conn.expected_return )
== fake_connected_just_try_oracledb._conn.expected_return
)
View file
@@ -10,7 +10,9 @@ from mylib.pgsql import PgDB
class FakePsycopg2Cursor: class FakePsycopg2Cursor:
"""Fake Psycopg2 cursor""" """Fake Psycopg2 cursor"""
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
self.expected_sql = expected_sql self.expected_sql = expected_sql
self.expected_params = expected_params self.expected_params = expected_params
self.expected_return = expected_return self.expected_return = expected_return
@@ -19,13 +21,25 @@ class FakePsycopg2Cursor:
def execute(self, sql, params=None): def execute(self, sql, params=None):
if self.expected_exception: if self.expected_exception:
raise psycopg2.Error(f'{self}.execute({sql}, {params}): expected exception') raise psycopg2.Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith('select '): if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
return self.expected_return return self.expected_return
def fetchall(self): def fetchall(self):
@@ -33,8 +47,8 @@ class FakePsycopg2Cursor:
def __repr__(self): def __repr__(self):
return ( return (
f'FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, ' f"FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, "
f'{self.expected_return}, {self.expected_just_try})' f"{self.expected_return}, {self.expected_just_try})"
) )
@@ -52,8 +66,9 @@ class FakePsycopg2:
allowed_kwargs = dict(dbname=str, user=str, password=(str, None), host=str) allowed_kwargs = dict(dbname=str, user=str, password=(str, None), host=str)
for arg, value in kwargs.items(): for arg, value in kwargs.items():
assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"' assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"'
assert isinstance(value, allowed_kwargs[arg]), \ assert isinstance(
f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})' value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
setattr(self, arg, value) setattr(self, arg, value)
def close(self): def close(self):
@@ -63,14 +78,16 @@ class FakePsycopg2:
self._check_just_try() self._check_just_try()
assert len(arg) == 1 and isinstance(arg[0], str) assert len(arg) == 1 and isinstance(arg[0], str)
if self.expected_exception: if self.expected_exception:
raise psycopg2.Error(f'set_client_encoding({arg[0]}): Expected exception') raise psycopg2.Error(f"set_client_encoding({arg[0]}): Expected exception")
return self.expected_return return self.expected_return
def cursor(self): def cursor(self):
return FakePsycopg2Cursor( return FakePsycopg2Cursor(
self.expected_sql, self.expected_params, self.expected_sql,
self.expected_return, self.expected_just_try or self.just_try, self.expected_params,
self.expected_exception self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
) )
def commit(self): def commit(self):
@@ -98,19 +115,19 @@ def fake_psycopg2_connect_just_try(**kwargs):
@pytest.fixture @pytest.fixture
def test_pgdb(): def test_pgdb():
return PgDB('127.0.0.1', 'user', 'password', 'dbname') return PgDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_pgdb(mocker): def fake_pgdb(mocker):
mocker.patch('psycopg2.connect', fake_psycopg2_connect) mocker.patch("psycopg2.connect", fake_psycopg2_connect)
return PgDB('127.0.0.1', 'user', 'password', 'dbname') return PgDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_just_try_pgdb(mocker): def fake_just_try_pgdb(mocker):
mocker.patch('psycopg2.connect', fake_psycopg2_connect_just_try) mocker.patch("psycopg2.connect", fake_psycopg2_connect_just_try)
return PgDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True) return PgDB("127.0.0.1", "user", "password", "dbname", just_try=True)
@pytest.fixture @pytest.fixture
@@ -125,13 +142,22 @@ def fake_connected_just_try_pgdb(fake_just_try_pgdb):
return fake_just_try_pgdb return fake_just_try_pgdb
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs): def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
return expected_return return expected_return
return mock_args return mock_args
@@ -139,13 +165,22 @@ def mock_doSQL_just_try(self, sql, params=None):  # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode" assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
return expected_return return expected_return
return mock_doSQL return mock_doSQL
@@ -159,15 +194,11 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter(): def test_combine_params_with_to_add_parameter():
assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict( assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs(): def test_combine_params_with_kargs():
assert PgDB._combine_params(dict(test1=1), test2=2) == dict( assert PgDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter(): def test_combine_params_with_kargs_and_to_add_parameter():
@@ -177,19 +208,16 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved(): def test_format_where_clauses_params_are_preserved():
args = ('test = test', dict(test1=1)) args = ("test = test", dict(test1=1))
assert PgDB._format_where_clauses(*args) == args assert PgDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw(): def test_format_where_clauses_raw():
assert PgDB._format_where_clauses('test = test') == (('test = test'), {}) assert PgDB._format_where_clauses("test = test") == ("test = test", {})
def test_format_where_clauses_tuple_clause_with_params(): def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ( where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2))
'test1 = %(test1)s AND test2 = %(test2)s',
dict(test1=1, test2=2)
)
assert PgDB._format_where_clauses(where_clauses) == where_clauses assert PgDB._format_where_clauses(where_clauses) == where_clauses
@@ -197,27 +225,23 @@ def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert PgDB._format_where_clauses(where_clauses) == ( assert PgDB._format_where_clauses(where_clauses) == (
'"test1" = %(test1)s AND "test2" = %(test2)s', '"test1" = %(test1)s AND "test2" = %(test2)s',
where_clauses where_clauses,
) )
def test_format_where_clauses_combined_types(): def test_format_where_clauses_combined_types():
where_clauses = ( where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4))
'test1 = 1',
('test2 LIKE %(test2)s', dict(test2=2)),
dict(test3=3, test4=4)
)
assert PgDB._format_where_clauses(where_clauses) == ( assert PgDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE %(test2)s AND "test3" = %(test3)s AND "test4" = %(test4)s', 'test1 = 1 AND test2 LIKE %(test2)s AND "test3" = %(test3)s AND "test4" = %(test4)s',
dict(test2=2, test3=3, test4=4) dict(test2=2, test3=3, test4=4),
) )
def test_format_where_clauses_with_where_op(): def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert PgDB._format_where_clauses(where_clauses, where_op='OR') == ( assert PgDB._format_where_clauses(where_clauses, where_op="OR") == (
'"test1" = %(test1)s OR "test2" = %(test2)s', '"test1" = %(test1)s OR "test2" = %(test2)s',
where_clauses where_clauses,
) )
@@ -226,7 +250,7 @@ def test_add_where_clauses():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert PgDB._add_where_clauses(sql, None, where_clauses) == ( assert PgDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s', sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s',
where_clauses where_clauses,
) )
@@ -236,26 +260,26 @@ def test_add_where_clauses_preserved_params():
params = dict(fake1=1) params = dict(fake1=1)
assert PgDB._add_where_clauses(sql, params.copy(), where_clauses) == ( assert PgDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s', sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s',
dict(**where_clauses, **params) dict(**where_clauses, **params),
) )
def test_add_where_clauses_with_op(): def test_add_where_clauses_with_op():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = ('test1=1', 'test2=2') where_clauses = ("test1=1", "test2=2")
assert PgDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( assert PgDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + ' WHERE test1=1 OR test2=2', sql + " WHERE test1=1 OR test2=2",
{} {},
) )
def test_add_where_clauses_with_duplicated_field(): def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=%(test1)s" sql = "UPDATE table SET test1=%(test1)s"
params = dict(test1='new_value') params = dict(test1="new_value")
where_clauses = dict(test1='where_value') where_clauses = dict(test1="where_value")
assert PgDB._add_where_clauses(sql, params, where_clauses) == ( assert PgDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE "test1" = %(test1_1)s', sql + ' WHERE "test1" = %(test1_1)s',
dict(test1='new_value', test1_1='where_value') dict(test1="new_value", test1_1="where_value"),
) )
@@ -267,74 +291,70 @@ def test_quote_table_name():
def test_insert(mocker, test_pgdb): def test_insert(mocker, test_pgdb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSQL', "mylib.pgsql.PgDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)', 'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)', values
values ),
)
) )
assert test_pgdb.insert('mytable', values) assert test_pgdb.insert("mytable", values)
def test_insert_just_try(mocker, test_pgdb): def test_insert_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.insert('mytable', dict(test1=1, test2=2), just_try=True) assert test_pgdb.insert("mytable", dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_pgdb): def test_update(mocker, test_pgdb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSQL', "mylib.pgsql.PgDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" = %(test3)s AND "test4" = %(test4)s', 'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" ='
dict(**values, **where_clauses) ' %(test3)s AND "test4" = %(test4)s',
) dict(**values, **where_clauses),
),
) )
assert test_pgdb.update('mytable', values, where_clauses) assert test_pgdb.update("mytable", values, where_clauses)
def test_update_just_try(mocker, test_pgdb): def test_update_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.update('mytable', dict(test1=1, test2=2), None, just_try=True) assert test_pgdb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_pgdb): def test_delete(mocker, test_pgdb):
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSQL', "mylib.pgsql.PgDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s', 'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s', where_clauses
where_clauses ),
)
) )
assert test_pgdb.delete('mytable', where_clauses) assert test_pgdb.delete("mytable", where_clauses)
def test_delete_just_try(mocker, test_pgdb): def test_delete_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.delete('mytable', None, just_try=True) assert test_pgdb.delete("mytable", None, just_try=True)
def test_truncate(mocker, test_pgdb): def test_truncate(mocker, test_pgdb):
mocker.patch( mocker.patch("mylib.pgsql.PgDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None))
'mylib.pgsql.PgDB.doSQL',
generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
)
assert test_pgdb.truncate('mytable') assert test_pgdb.truncate("mytable")
def test_truncate_just_try(mocker, test_pgdb): def test_truncate_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try)
assert test_pgdb.truncate('mytable', just_try=True) assert test_pgdb.truncate("mytable", just_try=True)
def test_select(mocker, test_pgdb): def test_select(mocker, test_pgdb):
fields = ('field1', 'field2') fields = ("field1", "field2")
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
expected_return = [ expected_return = [
dict(field1=1, field2=2), dict(field1=1, field2=2),
@@ -342,30 +362,28 @@ def test_select(mocker, test_pgdb):
] ]
order_by = "field1, DESC" order_by = "field1, DESC"
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSelect', "mylib.pgsql.PgDB.doSelect",
generate_mock_doSQL( generate_mock_doSQL(
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" = %(test4)s ORDER BY ' + order_by, 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" ='
where_clauses, expected_return " %(test4)s ORDER BY " + order_by,
) where_clauses,
expected_return,
),
) )
assert test_pgdb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return assert test_pgdb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
def test_select_without_field_and_order_by(mocker, test_pgdb): def test_select_without_field_and_order_by(mocker, test_pgdb):
mocker.patch( mocker.patch("mylib.pgsql.PgDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"'))
'mylib.pgsql.PgDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM "mytable"'
)
)
assert test_pgdb.select('mytable') assert test_pgdb.select("mytable")
def test_select_just_try(mocker, test_pgdb): def test_select_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try)
assert test_pgdb.select('mytable', None, None, just_try=True) assert test_pgdb.select("mytable", None, None, just_try=True)
# #
# Tests on main methods # Tests on main methods
@@ -374,18 +392,10 @@ def test_select_just_try(mocker, test_pgdb):
def test_connect(mocker, test_pgdb): def test_connect(mocker, test_pgdb):
expected_kwargs = dict( expected_kwargs = dict(
dbname=test_pgdb._db, dbname=test_pgdb._db, user=test_pgdb._user, host=test_pgdb._host, password=test_pgdb._pwd
user=test_pgdb._user,
host=test_pgdb._host,
password=test_pgdb._pwd
) )
mocker.patch( mocker.patch("psycopg2.connect", generate_mock_args(expected_kwargs=expected_kwargs))
'psycopg2.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_pgdb.connect() assert test_pgdb.connect()
@@ -399,61 +409,74 @@ def test_close_connected(fake_connected_pgdb):
def test_setEncoding(fake_connected_pgdb): def test_setEncoding(fake_connected_pgdb):
assert fake_connected_pgdb.setEncoding('utf8') assert fake_connected_pgdb.setEncoding("utf8")
def test_setEncoding_not_connected(fake_pgdb): def test_setEncoding_not_connected(fake_pgdb):
assert fake_pgdb.setEncoding('utf8') is False assert fake_pgdb.setEncoding("utf8") is False
def test_setEncoding_on_exception(fake_connected_pgdb): def test_setEncoding_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.setEncoding('utf8') is False assert fake_connected_pgdb.setEncoding("utf8") is False
def test_doSQL(fake_connected_pgdb): def test_doSQL(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s' fake_connected_pgdb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s"
fake_connected_pgdb._conn.expected_params = dict(test1=1) fake_connected_pgdb._conn.expected_params = dict(test1=1)
fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params) fake_connected_pgdb.doSQL(
fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params
)
def test_doSQL_without_params(fake_connected_pgdb): def test_doSQL_without_params(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table' fake_connected_pgdb._conn.expected_sql = "DELETE FROM table"
fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql) fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_pgdb): def test_doSQL_just_try(fake_connected_just_try_pgdb):
assert fake_connected_just_try_pgdb.doSQL('DELETE FROM table') assert fake_connected_just_try_pgdb.doSQL("DELETE FROM table")
def test_doSQL_on_exception(fake_connected_pgdb): def test_doSQL_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.doSQL('DELETE FROM table') is False assert fake_connected_pgdb.doSQL("DELETE FROM table") is False
def test_doSelect(fake_connected_pgdb): def test_doSelect(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_pgdb._conn.expected_params = dict(test1=1) fake_connected_pgdb._conn.expected_params = dict(test1=1)
fake_connected_pgdb._conn.expected_return = [dict(test1=1)] fake_connected_pgdb._conn.expected_return = [dict(test1=1)]
assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params) == fake_connected_pgdb._conn.expected_return assert (
fake_connected_pgdb.doSelect(
fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params
)
== fake_connected_pgdb._conn.expected_return
)
def test_doSelect_without_params(fake_connected_pgdb): def test_doSelect_without_params(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table' fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table"
fake_connected_pgdb._conn.expected_return = [dict(test1=1)] fake_connected_pgdb._conn.expected_return = [dict(test1=1)]
assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql) == fake_connected_pgdb._conn.expected_return assert (
fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql)
== fake_connected_pgdb._conn.expected_return
)
def test_doSelect_on_exception(fake_connected_pgdb): def test_doSelect_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.doSelect('SELECT * FROM table') is False assert fake_connected_pgdb.doSelect("SELECT * FROM table") is False
def test_doSelect_just_try(fake_connected_just_try_pgdb): def test_doSelect_just_try(fake_connected_just_try_pgdb):
fake_connected_just_try_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_just_try_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_just_try_pgdb._conn.expected_params = dict(test1=1) fake_connected_just_try_pgdb._conn.expected_params = dict(test1=1)
fake_connected_just_try_pgdb._conn.expected_return = [dict(test1=1)] fake_connected_just_try_pgdb._conn.expected_return = [dict(test1=1)]
assert fake_connected_just_try_pgdb.doSelect( assert (
fake_connected_just_try_pgdb.doSelect(
fake_connected_just_try_pgdb._conn.expected_sql, fake_connected_just_try_pgdb._conn.expected_sql,
fake_connected_just_try_pgdb._conn.expected_params fake_connected_just_try_pgdb._conn.expected_params,
) == fake_connected_just_try_pgdb._conn.expected_return )
== fake_connected_just_try_pgdb._conn.expected_return
)
View file
@@ -3,13 +3,14 @@
import datetime import datetime
import os import os
import pytest import pytest
from mylib.telltale import TelltaleFile from mylib.telltale import TelltaleFile
def test_create_telltale_file(tmp_path): def test_create_telltale_file(tmp_path):
filename = 'test' filename = "test"
file = TelltaleFile(filename=filename, dirpath=tmp_path) file = TelltaleFile(filename=filename, dirpath=tmp_path)
assert file.filename == filename assert file.filename == filename
assert file.dirpath == tmp_path assert file.dirpath == tmp_path
@@ -24,15 +25,15 @@ def test_create_telltale_file(tmp_path):
def test_create_telltale_file_with_filepath_and_invalid_dirpath(): def test_create_telltale_file_with_filepath_and_invalid_dirpath():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
TelltaleFile(filepath='/tmp/test', dirpath='/var/tmp') TelltaleFile(filepath="/tmp/test", dirpath="/var/tmp")
def test_create_telltale_file_with_filepath_and_invalid_filename(): def test_create_telltale_file_with_filepath_and_invalid_filename():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
TelltaleFile(filepath='/tmp/test', filename='other') TelltaleFile(filepath="/tmp/test", filename="other")
def test_remove_telltale_file(tmp_path): def test_remove_telltale_file(tmp_path):
file = TelltaleFile(filename='test', dirpath=tmp_path) file = TelltaleFile(filename="test", dirpath=tmp_path)
file.update() file.update()
assert file.remove() assert file.remove()