Introduce pyupgrade,isort,black and configure pre-commit hooks to run all testing tools before commit
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Benjamin Renard 2023-01-16 12:56:12 +01:00
parent a83c3d635f
commit 62c3fadf96
34 changed files with 2356 additions and 2026 deletions

39
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,39 @@
# Pre-commit hooks to run tests and ensure code is cleaned.
# See https://pre-commit.com for more information
repos:
- repo: local
hooks:
- id: pytest
name: pytest
entry: python3 -m pytest tests
language: system
pass_filenames: false
always_run: true
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint --extension-pkg-whitelist=cx_Oracle
language: system
types: [python]
require_serial: true
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
args: ['--max-line-length=100']
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: ['--keep-percent-format', '--py37-plus']
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
- id: black
args: ['--target-version', 'py37', '--line-length', '100']
- repo: https://github.com/PyCQA/isort
rev: 5.11.4
hooks:
- id: isort
args: ['--profile', 'black', '--line-length', '100']

View file

@ -8,5 +8,8 @@ disable=invalid-name,
too-many-nested-blocks, too-many-nested-blocks,
too-many-instance-attributes, too-many-instance-attributes,
too-many-lines, too-many-lines,
line-too-long,
duplicate-code, duplicate-code,
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=100

View file

@ -49,6 +49,41 @@ Just run `python setup.py install`
To know how to use these libs, you can take a look on *mylib.scripts* content or in *tests* directory. To know how to use these libs, you can take a look on *mylib.scripts* content or in *tests* directory.
## Code Style
[pylint](https://pypi.org/project/pylint/) is used to check for errors and enforces a coding standard, using thoses parameters:
```bash
pylint --extension-pkg-whitelist=cx_Oracle
```
[flake8](https://pypi.org/project/flake8/) is also used to check for errors and enforces a coding standard, using thoses parameters:
```bash
flake8 --max-line-length=100
```
[black](https://pypi.org/project/black/) is used to format the code, using thoses parameters:
```bash
black --target-version py37 --line-length 100
```
[isort](https://pypi.org/project/isort/) is used to format the imports, using those parameter:
```bash
isort --profile black --line-length 100
```
[pyupgrade](https://pypi.org/project/pyupgrade/) is used to automatically upgrade syntax, using those parameters:
```bash
pyupgrade --keep-percent-format --py37-plus
```
**Note:** There is `.pre-commit-config.yaml` to use [pre-commit](https://pre-commit.com/) to automatically run these tools before commits. After cloning the repository, execute `pre-commit install` to install the git hook.
## Copyright ## Copyright
Copyright (c) 2013-2021 Benjamin Renard <brenard@zionetrix.net> Copyright (c) 2013-2021 Benjamin Renard <brenard@zionetrix.net>

View file

@ -6,12 +6,12 @@
def increment_prefix(prefix): def increment_prefix(prefix):
""" Increment the given prefix with two spaces """ """Increment the given prefix with two spaces"""
return f'{prefix if prefix else " "} ' return f'{prefix if prefix else " "} '
def pretty_format_value(value, encoding='utf8', prefix=None): def pretty_format_value(value, encoding="utf8", prefix=None):
""" Returned pretty formated value to display """ """Returned pretty formated value to display"""
if isinstance(value, dict): if isinstance(value, dict):
return pretty_format_dict(value, encoding=encoding, prefix=prefix) return pretty_format_dict(value, encoding=encoding, prefix=prefix)
if isinstance(value, list): if isinstance(value, list):
@ -22,10 +22,10 @@ def pretty_format_value(value, encoding='utf8', prefix=None):
return f"'{value}'" return f"'{value}'"
if value is None: if value is None:
return "None" return "None"
return f'{value} ({type(value)})' return f"{value} ({type(value)})"
def pretty_format_value_in_list(value, encoding='utf8', prefix=None): def pretty_format_value_in_list(value, encoding="utf8", prefix=None):
""" """
Returned pretty formated value to display in list Returned pretty formated value to display in list
@ -34,42 +34,31 @@ def pretty_format_value_in_list(value, encoding='utf8', prefix=None):
""" """
prefix = prefix if prefix else "" prefix = prefix if prefix else ""
value = pretty_format_value(value, encoding, prefix) value = pretty_format_value(value, encoding, prefix)
if '\n' in value: if "\n" in value:
inc_prefix = increment_prefix(prefix) inc_prefix = increment_prefix(prefix)
value = "\n" + "\n".join([ value = "\n" + "\n".join([inc_prefix + line for line in value.split("\n")])
inc_prefix + line
for line in value.split('\n')
])
return value return value
def pretty_format_dict(value, encoding='utf8', prefix=None): def pretty_format_dict(value, encoding="utf8", prefix=None):
""" Returned pretty formated dict to display """ """Returned pretty formated dict to display"""
prefix = prefix if prefix else "" prefix = prefix if prefix else ""
result = [] result = []
for key in sorted(value.keys()): for key in sorted(value.keys()):
result.append( result.append(
f'{prefix}- {key} : ' f"{prefix}- {key} : "
+ pretty_format_value_in_list( + pretty_format_value_in_list(value[key], encoding=encoding, prefix=prefix)
value[key],
encoding=encoding,
prefix=prefix
)
) )
return "\n".join(result) return "\n".join(result)
def pretty_format_list(row, encoding='utf8', prefix=None): def pretty_format_list(row, encoding="utf8", prefix=None):
""" Returned pretty formated list to display """ """Returned pretty formated list to display"""
prefix = prefix if prefix else "" prefix = prefix if prefix else ""
result = [] result = []
for idx, values in enumerate(row): for idx, values in enumerate(row):
result.append( result.append(
f'{prefix}- #{idx} : ' f"{prefix}- #{idx} : "
+ pretty_format_value_in_list( + pretty_format_value_in_list(values, encoding=encoding, prefix=prefix)
values,
encoding=encoding,
prefix=prefix
)
) )
return "\n".join(result) return "\n".join(result)

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,7 @@
# -*- coding: utf-8 -*-
""" Basic SQL DB client """ """ Basic SQL DB client """
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -12,8 +9,9 @@ log = logging.getLogger(__name__)
# Exceptions # Exceptions
# #
class DBException(Exception): class DBException(Exception):
""" That is the base exception class for all the other exceptions provided by this module. """ """That is the base exception class for all the other exceptions provided by this module."""
def __init__(self, error, *args, **kwargs): def __init__(self, error, *args, **kwargs):
for arg, value in kwargs.items(): for arg, value in kwargs.items():
@ -29,7 +27,8 @@ class DBNotImplemented(DBException, RuntimeError):
def __init__(self, method, class_name): def __init__(self, method, class_name):
super().__init__( super().__init__(
"The method {method} is not yet implemented in class {class_name}", "The method {method} is not yet implemented in class {class_name}",
method=method, class_name=class_name method=method,
class_name=class_name,
) )
@ -39,10 +38,7 @@ class DBFailToConnect(DBException, RuntimeError):
""" """
def __init__(self, uri): def __init__(self, uri):
super().__init__( super().__init__("An error occured during database connection ({uri})", uri=uri)
"An error occured during database connection ({uri})",
uri=uri
)
class DBDuplicatedSQLParameter(DBException, KeyError): class DBDuplicatedSQLParameter(DBException, KeyError):
@ -53,8 +49,7 @@ class DBDuplicatedSQLParameter(DBException, KeyError):
def __init__(self, parameter_name): def __init__(self, parameter_name):
super().__init__( super().__init__(
"Duplicated SQL parameter '{parameter_name}'", "Duplicated SQL parameter '{parameter_name}'", parameter_name=parameter_name
parameter_name=parameter_name
) )
@ -65,10 +60,7 @@ class DBUnsupportedWHEREClauses(DBException, TypeError):
""" """
def __init__(self, where_clauses): def __init__(self, where_clauses):
super().__init__( super().__init__("Unsupported WHERE clauses: {where_clauses}", where_clauses=where_clauses)
"Unsupported WHERE clauses: {where_clauses}",
where_clauses=where_clauses
)
class DBInvalidOrderByClause(DBException, TypeError): class DBInvalidOrderByClause(DBException, TypeError):
@ -79,13 +71,14 @@ class DBInvalidOrderByClause(DBException, TypeError):
def __init__(self, order_by): def __init__(self, order_by):
super().__init__( super().__init__(
"Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values (ordering field name and direction)", "Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values"
order_by=order_by " (ordering field name and direction)",
order_by=order_by,
) )
class DB: class DB:
""" Database client """ """Database client"""
just_try = False just_try = False
@ -93,14 +86,14 @@ class DB:
self.just_try = just_try self.just_try = just_try
self._conn = None self._conn = None
for arg, value in kwargs.items(): for arg, value in kwargs.items():
setattr(self, f'_{arg}', value) setattr(self, f"_{arg}", value)
def connect(self, exit_on_error=True): def connect(self, exit_on_error=True):
""" Connect to DB server """ """Connect to DB server"""
raise DBNotImplemented('connect', self.__class__.__name__) raise DBNotImplemented("connect", self.__class__.__name__)
def close(self): def close(self):
""" Close connection with DB server (if opened) """ """Close connection with DB server (if opened)"""
if self._conn: if self._conn:
self._conn.close() self._conn.close()
self._conn = None self._conn = None
@ -110,12 +103,11 @@ class DB:
log.debug( log.debug(
'Run SQL query "%s" %s', 'Run SQL query "%s" %s',
sql, sql,
"with params = {0}".format( # pylint: disable=consider-using-f-string "with params = {}".format( # pylint: disable=consider-using-f-string
', '.join([ ", ".join([f"{key} = {value}" for key, value in params.items()])
f'{key} = {value}' if params
for key, value in params.items() else "without params"
]) if params else "without params" ),
)
) )
@staticmethod @staticmethod
@ -123,12 +115,11 @@ class DB:
log.exception( log.exception(
'Error during SQL query "%s" %s', 'Error during SQL query "%s" %s',
sql, sql,
"with params = {0}".format( # pylint: disable=consider-using-f-string "with params = {}".format( # pylint: disable=consider-using-f-string
', '.join([ ", ".join([f"{key} = {value}" for key, value in params.items()])
f'{key} = {value}' if params
for key, value in params.items() else "without params"
]) if params else "without params" ),
)
) )
def doSQL(self, sql, params=None): def doSQL(self, sql, params=None):
@ -141,7 +132,7 @@ class DB:
:return: True on success, False otherwise :return: True on success, False otherwise
:rtype: bool :rtype: bool
""" """
raise DBNotImplemented('doSQL', self.__class__.__name__) raise DBNotImplemented("doSQL", self.__class__.__name__)
def doSelect(self, sql, params=None): def doSelect(self, sql, params=None):
""" """
@ -153,7 +144,7 @@ class DB:
:return: List of selected rows as dict on success, False otherwise :return: List of selected rows as dict on success, False otherwise
:rtype: list, bool :rtype: list, bool
""" """
raise DBNotImplemented('doSelect', self.__class__.__name__) raise DBNotImplemented("doSelect", self.__class__.__name__)
# #
# SQL helpers # SQL helpers
@ -161,22 +152,20 @@ class DB:
@staticmethod @staticmethod
def _quote_table_name(table): def _quote_table_name(table):
""" Quote table name """ """Quote table name"""
return '"{0}"'.format( # pylint: disable=consider-using-f-string return '"{}"'.format( # pylint: disable=consider-using-f-string
'"."'.join( '"."'.join(table.split("."))
table.split('.')
)
) )
@staticmethod @staticmethod
def _quote_field_name(field): def _quote_field_name(field):
""" Quote table name """ """Quote table name"""
return f'"{field}"' return f'"{field}"'
@staticmethod @staticmethod
def format_param(param): def format_param(param):
""" Format SQL query parameter for prepared query """ """Format SQL query parameter for prepared query"""
return f'%({param})s' return f"%({param})s"
@classmethod @classmethod
def _combine_params(cls, params, to_add=None, **kwargs): def _combine_params(cls, params, to_add=None, **kwargs):
@ -201,7 +190,8 @@ class DB:
- a dict of WHERE clauses with field name as key and WHERE clause value as value - a dict of WHERE clauses with field name as key and WHERE clause value as value
- a list of any of previous valid WHERE clauses - a list of any of previous valid WHERE clauses
:param params: Dict of other already set SQL query parameters (optional) :param params: Dict of other already set SQL query parameters (optional)
:param where_op: SQL operator used to combine WHERE clauses together (optional, default: AND) :param where_op: SQL operator used to combine WHERE clauses together (optional, default:
AND)
:return: A tuple of two elements: raw SQL WHERE combined clauses and parameters on success :return: A tuple of two elements: raw SQL WHERE combined clauses and parameters on success
:rtype: string, bool :rtype: string, bool
@ -209,24 +199,27 @@ class DB:
if params is None: if params is None:
params = {} params = {}
if where_op is None: if where_op is None:
where_op = 'AND' where_op = "AND"
if isinstance(where_clauses, str): if isinstance(where_clauses, str):
return (where_clauses, params) return (where_clauses, params)
if isinstance(where_clauses, tuple) and len(where_clauses) == 2 and isinstance(where_clauses[1], dict): if (
isinstance(where_clauses, tuple)
and len(where_clauses) == 2
and isinstance(where_clauses[1], dict)
):
cls._combine_params(params, where_clauses[1]) cls._combine_params(params, where_clauses[1])
return (where_clauses[0], params) return (where_clauses[0], params)
if isinstance(where_clauses, (list, tuple)): if isinstance(where_clauses, (list, tuple)):
sql_where_clauses = [] sql_where_clauses = []
for where_clause in where_clauses: for where_clause in where_clauses:
sql2, params = cls._format_where_clauses(where_clause, params=params, where_op=where_op) sql2, params = cls._format_where_clauses(
where_clause, params=params, where_op=where_op
)
sql_where_clauses.append(sql2) sql_where_clauses.append(sql2)
return ( return (f" {where_op} ".join(sql_where_clauses), params)
f' {where_op} '.join(sql_where_clauses),
params
)
if isinstance(where_clauses, dict): if isinstance(where_clauses, dict):
sql_where_clauses = [] sql_where_clauses = []
@ -235,16 +228,13 @@ class DB:
if field in params: if field in params:
idx = 1 idx = 1
while param in params: while param in params:
param = f'{field}_{idx}' param = f"{field}_{idx}"
idx += 1 idx += 1
cls._combine_params(params, {param: value}) cls._combine_params(params, {param: value})
sql_where_clauses.append( sql_where_clauses.append(
f'{cls._quote_field_name(field)} = {cls.format_param(param)}' f"{cls._quote_field_name(field)} = {cls.format_param(param)}"
) )
return ( return (f" {where_op} ".join(sql_where_clauses), params)
f' {where_op} '.join(sql_where_clauses),
params
)
raise DBUnsupportedWHEREClauses(where_clauses) raise DBUnsupportedWHEREClauses(where_clauses)
@classmethod @classmethod
@ -255,29 +245,26 @@ class DB:
:param sql: The SQL query to complete :param sql: The SQL query to complete
:param params: The dict of parameters of the SQL query to complete :param params: The dict of parameters of the SQL query to complete
:param where_clauses: The WHERE clause (see _format_where_clauses()) :param where_clauses: The WHERE clause (see _format_where_clauses())
:param where_op: SQL operator used to combine WHERE clauses together (optional, default: see _format_where_clauses()) :param where_op: SQL operator used to combine WHERE clauses together (optional, default:
see _format_where_clauses())
:return: :return:
:rtype: A tuple of two elements: raw SQL WHERE combined clauses and parameters :rtype: A tuple of two elements: raw SQL WHERE combined clauses and parameters
""" """
if where_clauses: if where_clauses:
sql_where, params = cls._format_where_clauses(where_clauses, params=params, where_op=where_op) sql_where, params = cls._format_where_clauses(
where_clauses, params=params, where_op=where_op
)
sql += " WHERE " + sql_where sql += " WHERE " + sql_where
return (sql, params) return (sql, params)
def insert(self, table, values, just_try=False): def insert(self, table, values, just_try=False):
""" Run INSERT SQL query """ """Run INSERT SQL query"""
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
sql = 'INSERT INTO {0} ({1}) VALUES ({2})'.format( sql = "INSERT INTO {} ({}) VALUES ({})".format(
self._quote_table_name(table), self._quote_table_name(table),
', '.join([ ", ".join([self._quote_field_name(field) for field in values.keys()]),
self._quote_field_name(field) ", ".join([self.format_param(key) for key in values]),
for field in values.keys()
]),
", ".join([
self.format_param(key)
for key in values
])
) )
if just_try: if just_try:
@ -291,21 +278,20 @@ class DB:
return True return True
def update(self, table, values, where_clauses, where_op=None, just_try=False): def update(self, table, values, where_clauses, where_op=None, just_try=False):
""" Run UPDATE SQL query """ """Run UPDATE SQL query"""
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
sql = 'UPDATE {0} SET {1}'.format( sql = "UPDATE {} SET {}".format(
self._quote_table_name(table), self._quote_table_name(table),
", ".join([ ", ".join(
f'{self._quote_field_name(key)} = {self.format_param(key)}' [f"{self._quote_field_name(key)} = {self.format_param(key)}" for key in values]
for key in values ),
])
) )
params = values params = values
try: try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error('Fail to add WHERE clauses', exc_info=True) log.error("Fail to add WHERE clauses", exc_info=True)
return False return False
if just_try: if just_try:
@ -318,15 +304,15 @@ class DB:
return False return False
return True return True
def delete(self, table, where_clauses, where_op='AND', just_try=False): def delete(self, table, where_clauses, where_op="AND", just_try=False):
""" Run DELETE SQL query """ """Run DELETE SQL query"""
sql = f'DELETE FROM {self._quote_table_name(table)}' sql = f"DELETE FROM {self._quote_table_name(table)}"
params = {} params = {}
try: try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error('Fail to add WHERE clauses', exc_info=True) log.error("Fail to add WHERE clauses", exc_info=True)
return False return False
if just_try: if just_try:
@ -340,8 +326,8 @@ class DB:
return True return True
def truncate(self, table, just_try=False): def truncate(self, table, just_try=False):
""" Run TRUNCATE SQL query """ """Run TRUNCATE SQL query"""
sql = f'TRUNCATE TABLE {self._quote_table_name(table)}' sql = f"TRUNCATE TABLE {self._quote_table_name(table)}"
if just_try: if just_try:
log.debug("Just-try mode: execute TRUNCATE query: %s", sql) log.debug("Just-try mode: execute TRUNCATE query: %s", sql)
@ -353,33 +339,36 @@ class DB:
return False return False
return True return True
def select(self, table, where_clauses=None, fields=None, where_op='AND', order_by=None, just_try=False): def select(
""" Run SELECT SQL query """ self, table, where_clauses=None, fields=None, where_op="AND", order_by=None, just_try=False
):
"""Run SELECT SQL query"""
sql = "SELECT " sql = "SELECT "
if fields is None: if fields is None:
sql += "*" sql += "*"
elif isinstance(fields, str): elif isinstance(fields, str):
sql += f'{self._quote_field_name(fields)}' sql += f"{self._quote_field_name(fields)}"
else: else:
sql += ', '.join([self._quote_field_name(field) for field in fields]) sql += ", ".join([self._quote_field_name(field) for field in fields])
sql += f' FROM {self._quote_table_name(table)}' sql += f" FROM {self._quote_table_name(table)}"
params = {} params = {}
try: try:
sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op)
except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses):
log.error('Fail to add WHERE clauses', exc_info=True) log.error("Fail to add WHERE clauses", exc_info=True)
return False return False
if order_by: if order_by:
if isinstance(order_by, str): if isinstance(order_by, str):
sql += f' ORDER BY {order_by}' sql += f" ORDER BY {order_by}"
elif ( elif (
isinstance(order_by, (list, tuple)) and len(order_by) == 2 isinstance(order_by, (list, tuple))
and len(order_by) == 2
and isinstance(order_by[0], str) and isinstance(order_by[0], str)
and isinstance(order_by[1], str) and isinstance(order_by[1], str)
and order_by[1].upper() in ('ASC', 'UPPER') and order_by[1].upper() in ("ASC", "UPPER")
): ):
sql += f' ORDER BY "{order_by[0]}" {order_by[1].upper()}' sql += f' ORDER BY "{order_by[0]}" {order_by[1].upper()}'
else: else:

View file

@ -1,49 +1,51 @@
# -*- coding: utf-8 -*-
""" Email client to forge and send emails """ """ Email client to forge and send emails """
import email.utils
import logging import logging
import os import os
import smtplib import smtplib
import email.utils
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
from email.encoders import encode_base64 from email.encoders import encode_base64
from email.mime.base import MIMEBase
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from mako.template import Template as MakoTemplate from mako.template import Template as MakoTemplate
from mylib.config import ConfigurableObject from mylib.config import (
from mylib.config import BooleanOption BooleanOption,
from mylib.config import IntegerOption ConfigurableObject,
from mylib.config import PasswordOption IntegerOption,
from mylib.config import StringOption PasswordOption,
StringOption,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inheritance,too-many-instance-attributes class EmailClient(
ConfigurableObject
): # pylint: disable=useless-object-inheritance,too-many-instance-attributes
""" """
Email client Email client
This class abstract all interactions with the SMTP server. This class abstract all interactions with the SMTP server.
""" """
_config_name = 'email' _config_name = "email"
_config_comment = 'Email' _config_comment = "Email"
_defaults = { _defaults = {
'smtp_host': 'localhost', "smtp_host": "localhost",
'smtp_port': 25, "smtp_port": 25,
'smtp_ssl': False, "smtp_ssl": False,
'smtp_tls': False, "smtp_tls": False,
'smtp_user': None, "smtp_user": None,
'smtp_password': None, "smtp_password": None,
'smtp_debug': False, "smtp_debug": False,
'sender_name': 'No reply', "sender_name": "No reply",
'sender_email': 'noreply@localhost', "sender_email": "noreply@localhost",
'encoding': 'utf-8', "encoding": "utf-8",
'catch_all_addr': None, "catch_all_addr": None,
'just_try': False, "just_try": False,
} }
templates = {} templates = {}
@ -55,61 +57,107 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit
self.templates = templates if templates else {} self.templates = templates if templates else {}
# pylint: disable=arguments-differ,arguments-renamed # pylint: disable=arguments-differ,arguments-renamed
def configure(self, use_smtp=True, just_try=True, ** kwargs): def configure(self, use_smtp=True, just_try=True, **kwargs):
""" Configure options on registered mylib.Config object """ """Configure options on registered mylib.Config object"""
section = super().configure(**kwargs) section = super().configure(**kwargs)
if use_smtp: if use_smtp:
section.add_option( section.add_option(
StringOption, 'smtp_host', default=self._defaults['smtp_host'], StringOption,
comment='SMTP server hostname/IP address') "smtp_host",
default=self._defaults["smtp_host"],
comment="SMTP server hostname/IP address",
)
section.add_option( section.add_option(
IntegerOption, 'smtp_port', default=self._defaults['smtp_port'], IntegerOption,
comment='SMTP server port') "smtp_port",
default=self._defaults["smtp_port"],
comment="SMTP server port",
)
section.add_option( section.add_option(
BooleanOption, 'smtp_ssl', default=self._defaults['smtp_ssl'], BooleanOption,
comment='Use SSL on SMTP server connection') "smtp_ssl",
default=self._defaults["smtp_ssl"],
comment="Use SSL on SMTP server connection",
)
section.add_option( section.add_option(
BooleanOption, 'smtp_tls', default=self._defaults['smtp_tls'], BooleanOption,
comment='Use TLS on SMTP server connection') "smtp_tls",
default=self._defaults["smtp_tls"],
comment="Use TLS on SMTP server connection",
)
section.add_option( section.add_option(
StringOption, 'smtp_user', default=self._defaults['smtp_user'], StringOption,
comment='SMTP authentication username') "smtp_user",
default=self._defaults["smtp_user"],
comment="SMTP authentication username",
)
section.add_option( section.add_option(
PasswordOption, 'smtp_password', default=self._defaults['smtp_password'], PasswordOption,
"smtp_password",
default=self._defaults["smtp_password"],
comment='SMTP authentication password (set to "keyring" to use XDG keyring)', comment='SMTP authentication password (set to "keyring" to use XDG keyring)',
username_option='smtp_user', keyring_value='keyring') username_option="smtp_user",
keyring_value="keyring",
)
section.add_option( section.add_option(
BooleanOption, 'smtp_debug', default=self._defaults['smtp_debug'], BooleanOption,
comment='Enable SMTP debugging') "smtp_debug",
default=self._defaults["smtp_debug"],
comment="Enable SMTP debugging",
)
section.add_option( section.add_option(
StringOption, 'sender_name', default=self._defaults['sender_name'], StringOption,
comment='Sender name') "sender_name",
default=self._defaults["sender_name"],
comment="Sender name",
)
section.add_option( section.add_option(
StringOption, 'sender_email', default=self._defaults['sender_email'], StringOption,
comment='Sender email address') "sender_email",
default=self._defaults["sender_email"],
comment="Sender email address",
)
section.add_option( section.add_option(
StringOption, 'encoding', default=self._defaults['encoding'], StringOption, "encoding", default=self._defaults["encoding"], comment="Email encoding"
comment='Email encoding') )
section.add_option( section.add_option(
StringOption, 'catch_all_addr', default=self._defaults['catch_all_addr'], StringOption,
comment='Catch all sent emails to this specified email address') "catch_all_addr",
default=self._defaults["catch_all_addr"],
comment="Catch all sent emails to this specified email address",
)
if just_try: if just_try:
section.add_option( section.add_option(
BooleanOption, 'just_try', default=self._defaults['just_try'], BooleanOption,
comment='Just-try mode: do not really send emails') "just_try",
default=self._defaults["just_try"],
comment="Just-try mode: do not really send emails",
)
return section return section
def forge_message(self, rcpt_to, subject=None, html_body=None, text_body=None, # pylint: disable=too-many-arguments,too-many-locals def forge_message(
attachment_files=None, attachment_payloads=None, sender_name=None, self,
sender_email=None, encoding=None, template=None, **template_vars): rcpt_to,
subject=None,
html_body=None,
text_body=None, # pylint: disable=too-many-arguments,too-many-locals
attachment_files=None,
attachment_payloads=None,
sender_name=None,
sender_email=None,
encoding=None,
template=None,
**template_vars,
):
""" """
Forge a message Forge a message
:param rcpt_to: The recipient of the email. Could be a tuple(name, email) or just the email of the recipient. :param rcpt_to: The recipient of the email. Could be a tuple(name, email) or
just the email of the recipient.
:param subject: The subject of the email. :param subject: The subject of the email.
:param html_body: The HTML body of the email :param html_body: The HTML body of the email
:param text_body: The plain text body of the email :param text_body: The plain text body of the email
@ -122,64 +170,69 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit
All other parameters will be consider as template variables. All other parameters will be consider as template variables.
""" """
msg = MIMEMultipart('alternative') msg = MIMEMultipart("alternative")
msg['To'] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to msg["To"] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to
msg['From'] = email.utils.formataddr( msg["From"] = email.utils.formataddr(
( (
sender_name or self._get_option('sender_name'), sender_name or self._get_option("sender_name"),
sender_email or self._get_option('sender_email') sender_email or self._get_option("sender_email"),
) )
) )
if subject: if subject:
msg['Subject'] = subject.format(**template_vars) msg["Subject"] = subject.format(**template_vars)
msg['Date'] = email.utils.formatdate(None, True) msg["Date"] = email.utils.formatdate(None, True)
encoding = encoding if encoding else self._get_option('encoding') encoding = encoding if encoding else self._get_option("encoding")
if template: if template:
assert template in self.templates, f'Unknwon template {template}' assert template in self.templates, f"Unknwon template {template}"
# Handle subject from template # Handle subject from template
if not subject: if not subject:
assert self.templates[template].get('subject'), f'No subject defined in template {template}' assert self.templates[template].get(
msg['Subject'] = self.templates[template]['subject'].format(**template_vars) "subject"
), f"No subject defined in template {template}"
msg["Subject"] = self.templates[template]["subject"].format(**template_vars)
# Put HTML part in last one to prefered it # Put HTML part in last one to prefered it
parts = [] parts = []
if self.templates[template].get('text'): if self.templates[template].get("text"):
if isinstance(self.templates[template]['text'], MakoTemplate): if isinstance(self.templates[template]["text"], MakoTemplate):
parts.append((self.templates[template]['text'].render(**template_vars), 'plain')) parts.append(
(self.templates[template]["text"].render(**template_vars), "plain")
)
else: else:
parts.append((self.templates[template]['text'].format(**template_vars), 'plain')) parts.append(
if self.templates[template].get('html'): (self.templates[template]["text"].format(**template_vars), "plain")
if isinstance(self.templates[template]['html'], MakoTemplate): )
parts.append((self.templates[template]['html'].render(**template_vars), 'html')) if self.templates[template].get("html"):
if isinstance(self.templates[template]["html"], MakoTemplate):
parts.append((self.templates[template]["html"].render(**template_vars), "html"))
else: else:
parts.append((self.templates[template]['html'].format(**template_vars), 'html')) parts.append((self.templates[template]["html"].format(**template_vars), "html"))
for body, mime_type in parts: for body, mime_type in parts:
msg.attach(MIMEText(body.encode(encoding), mime_type, _charset=encoding)) msg.attach(MIMEText(body.encode(encoding), mime_type, _charset=encoding))
else: else:
assert subject, 'No subject provided' assert subject, "No subject provided"
if text_body: if text_body:
msg.attach(MIMEText(text_body.encode(encoding), 'plain', _charset=encoding)) msg.attach(MIMEText(text_body.encode(encoding), "plain", _charset=encoding))
if html_body: if html_body:
msg.attach(MIMEText(html_body.encode(encoding), 'html', _charset=encoding)) msg.attach(MIMEText(html_body.encode(encoding), "html", _charset=encoding))
if attachment_files: if attachment_files:
for filepath in attachment_files: for filepath in attachment_files:
with open(filepath, 'rb') as fp: with open(filepath, "rb") as fp:
part = MIMEBase('application', "octet-stream") part = MIMEBase("application", "octet-stream")
part.set_payload(fp.read()) part.set_payload(fp.read())
encode_base64(part) encode_base64(part)
part.add_header( part.add_header(
'Content-Disposition', "Content-Disposition",
f'attachment; filename="{os.path.basename(filepath)}"') f'attachment; filename="{os.path.basename(filepath)}"',
)
msg.attach(part) msg.attach(part)
if attachment_payloads: if attachment_payloads:
for filename, payload in attachment_payloads: for filename, payload in attachment_payloads:
part = MIMEBase('application', "octet-stream") part = MIMEBase("application", "octet-stream")
part.set_payload(payload) part.set_payload(payload)
encode_base64(part) encode_base64(part)
part.add_header( part.add_header("Content-Disposition", f'attachment; filename="{filename}"')
'Content-Disposition',
f'attachment; filename="{filename}"')
msg.attach(part) msg.attach(part)
return msg return msg
@ -192,200 +245,184 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit
:param msg: The message of this email (as MIMEBase or derivated classes) :param msg: The message of this email (as MIMEBase or derivated classes)
:param subject: The subject of the email (only if the message is not provided :param subject: The subject of the email (only if the message is not provided
using msg parameter) using msg parameter)
:param just_try: Enable just try mode (do not really send email, default: as defined on initialization) :param just_try: Enable just try mode (do not really send email, default: as defined on
initialization)
All other parameters will be consider as parameters to forge the message All other parameters will be consider as parameters to forge the message
(only if the message is not provided using msg parameter). (only if the message is not provided using msg parameter).
""" """
msg = msg if msg else self.forge_message(rcpt_to, subject, **forge_args) msg = msg if msg else self.forge_message(rcpt_to, subject, **forge_args)
if just_try or self._get_option('just_try'): if just_try or self._get_option("just_try"):
log.debug('Just-try mode: do not really send this email to %s (subject="%s")', rcpt_to, subject or msg.get('subject', 'No subject')) log.debug(
'Just-try mode: do not really send this email to %s (subject="%s")',
rcpt_to,
subject or msg.get("subject", "No subject"),
)
return True return True
catch_addr = self._get_option('catch_all_addr') catch_addr = self._get_option("catch_all_addr")
if catch_addr: if catch_addr:
log.debug('Catch email originaly send to %s to %s', rcpt_to, catch_addr) log.debug("Catch email originaly send to %s to %s", rcpt_to, catch_addr)
rcpt_to = catch_addr rcpt_to = catch_addr
smtp_host = self._get_option('smtp_host') smtp_host = self._get_option("smtp_host")
smtp_port = self._get_option('smtp_port') smtp_port = self._get_option("smtp_port")
try: try:
if self._get_option('smtp_ssl'): if self._get_option("smtp_ssl"):
logging.info("Establish SSL connection to server %s:%s", smtp_host, smtp_port) logging.info("Establish SSL connection to server %s:%s", smtp_host, smtp_port)
server = smtplib.SMTP_SSL(smtp_host, smtp_port) server = smtplib.SMTP_SSL(smtp_host, smtp_port)
else: else:
logging.info("Establish connection to server %s:%s", smtp_host, smtp_port) logging.info("Establish connection to server %s:%s", smtp_host, smtp_port)
server = smtplib.SMTP(smtp_host, smtp_port) server = smtplib.SMTP(smtp_host, smtp_port)
if self._get_option('smtp_tls'): if self._get_option("smtp_tls"):
logging.info('Start TLS on SMTP connection') logging.info("Start TLS on SMTP connection")
server.starttls() server.starttls()
except smtplib.SMTPException: except smtplib.SMTPException:
log.error('Error connecting to SMTP server %s:%s', smtp_host, smtp_port, exc_info=True) log.error("Error connecting to SMTP server %s:%s", smtp_host, smtp_port, exc_info=True)
return False return False
if self._get_option('smtp_debug'): if self._get_option("smtp_debug"):
server.set_debuglevel(True) server.set_debuglevel(True)
smtp_user = self._get_option('smtp_user') smtp_user = self._get_option("smtp_user")
smtp_password = self._get_option('smtp_password') smtp_password = self._get_option("smtp_password")
if smtp_user and smtp_password: if smtp_user and smtp_password:
try: try:
log.info('Try to authenticate on SMTP connection as %s', smtp_user) log.info("Try to authenticate on SMTP connection as %s", smtp_user)
server.login(smtp_user, smtp_password) server.login(smtp_user, smtp_password)
except smtplib.SMTPException: except smtplib.SMTPException:
log.error( log.error(
'Error authenticating on SMTP server %s:%s with user %s', "Error authenticating on SMTP server %s:%s with user %s",
smtp_host, smtp_port, smtp_user, exc_info=True) smtp_host,
smtp_port,
smtp_user,
exc_info=True,
)
return False return False
error = False error = False
try: try:
log.info('Sending email to %s', rcpt_to) log.info("Sending email to %s", rcpt_to)
server.sendmail( server.sendmail(
self._get_option('sender_email'), self._get_option("sender_email"),
[rcpt_to[1] if isinstance(rcpt_to, tuple) else rcpt_to], [rcpt_to[1] if isinstance(rcpt_to, tuple) else rcpt_to],
msg.as_string() msg.as_string(),
) )
except smtplib.SMTPException: except smtplib.SMTPException:
error = True error = True
log.error('Error sending email to %s', rcpt_to, exc_info=True) log.error("Error sending email to %s", rcpt_to, exc_info=True)
finally: finally:
server.quit() server.quit()
return not error return not error
if __name__ == '__main__': if __name__ == "__main__":
# Run tests # Run tests
import argparse
import datetime import datetime
import sys import sys
import argparse
# Options parser # Options parser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'-v', '--verbose', "-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode"
action="store_true",
dest="verbose",
help="Enable verbose mode"
) )
parser.add_argument( parser.add_argument(
'-d', '--debug', "-d", "--debug", action="store_true", dest="debug", help="Enable debug mode"
action="store_true",
dest="debug",
help="Enable debug mode"
) )
parser.add_argument( parser.add_argument(
'-l', '--log-file', "-l", "--log-file", action="store", type=str, dest="logfile", help="Log file path"
action="store",
type=str,
dest="logfile",
help="Log file path"
) )
parser.add_argument( parser.add_argument(
'-j', '--just-try', "-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode"
action="store_true",
dest="just_try",
help="Enable just-try mode"
) )
email_opts = parser.add_argument_group('Email options') email_opts = parser.add_argument_group("Email options")
email_opts.add_argument( email_opts.add_argument(
'-H', '--smtp-host', "-H", "--smtp-host", action="store", type=str, dest="email_smtp_host", help="SMTP host"
action="store",
type=str,
dest="email_smtp_host",
help="SMTP host"
) )
email_opts.add_argument( email_opts.add_argument(
'-P', '--smtp-port', "-P", "--smtp-port", action="store", type=int, dest="email_smtp_port", help="SMTP port"
action="store",
type=int,
dest="email_smtp_port",
help="SMTP port"
) )
email_opts.add_argument( email_opts.add_argument(
'-S', '--smtp-ssl', "-S", "--smtp-ssl", action="store_true", dest="email_smtp_ssl", help="Use SSL"
action="store_true",
dest="email_smtp_ssl",
help="Use SSL"
) )
email_opts.add_argument( email_opts.add_argument(
'-T', '--smtp-tls', "-T", "--smtp-tls", action="store_true", dest="email_smtp_tls", help="Use TLS"
action="store_true",
dest="email_smtp_tls",
help="Use TLS"
) )
email_opts.add_argument( email_opts.add_argument(
'-u', '--smtp-user', "-u", "--smtp-user", action="store", type=str, dest="email_smtp_user", help="SMTP username"
action="store",
type=str,
dest="email_smtp_user",
help="SMTP username"
) )
email_opts.add_argument( email_opts.add_argument(
'-p', '--smtp-password', "-p",
"--smtp-password",
action="store", action="store",
type=str, type=str,
dest="email_smtp_password", dest="email_smtp_password",
help="SMTP password" help="SMTP password",
) )
email_opts.add_argument( email_opts.add_argument(
'-D', '--smtp-debug', "-D",
"--smtp-debug",
action="store_true", action="store_true",
dest="email_smtp_debug", dest="email_smtp_debug",
help="Debug SMTP connection" help="Debug SMTP connection",
) )
email_opts.add_argument( email_opts.add_argument(
'-e', '--email-encoding', "-e",
"--email-encoding",
action="store", action="store",
type=str, type=str,
dest="email_encoding", dest="email_encoding",
help="SMTP encoding" help="SMTP encoding",
) )
email_opts.add_argument( email_opts.add_argument(
'-f', '--sender-name', "-f",
"--sender-name",
action="store", action="store",
type=str, type=str,
dest="email_sender_name", dest="email_sender_name",
help="Sender name" help="Sender name",
) )
email_opts.add_argument( email_opts.add_argument(
'-F', '--sender-email', "-F",
"--sender-email",
action="store", action="store",
type=str, type=str,
dest="email_sender_email", dest="email_sender_email",
help="Sender email" help="Sender email",
) )
email_opts.add_argument( email_opts.add_argument(
'-C', '--catch-all', "-C",
"--catch-all",
action="store", action="store",
type=str, type=str,
dest="email_catch_all", dest="email_catch_all",
help="Catch all sent email: specify catch recipient email address" help="Catch all sent email: specify catch recipient email address",
) )
test_opts = parser.add_argument_group('Test email options') test_opts = parser.add_argument_group("Test email options")
test_opts.add_argument( test_opts.add_argument(
'-t', '--to', "-t",
"--to",
action="store", action="store",
type=str, type=str,
dest="test_to", dest="test_to",
@ -393,7 +430,8 @@ if __name__ == '__main__':
) )
test_opts.add_argument( test_opts.add_argument(
'-m', '--mako', "-m",
"--mako",
action="store_true", action="store_true",
dest="test_mako", dest="test_mako",
help="Test mako templating", help="Test mako templating",
@ -402,11 +440,11 @@ if __name__ == '__main__':
options = parser.parse_args() options = parser.parse_args()
if not options.test_to: if not options.test_to:
parser.error('You must specify test email recipient using -t/--to parameter') parser.error("You must specify test email recipient using -t/--to parameter")
sys.exit(1) sys.exit(1)
# Initialize logs # Initialize logs
logformat = '%(asctime)s - Test EmailClient - %(levelname)s - %(message)s' logformat = "%(asctime)s - Test EmailClient - %(levelname)s - %(message)s"
if options.debug: if options.debug:
loglevel = logging.DEBUG loglevel = logging.DEBUG
elif options.verbose: elif options.verbose:
@ -421,9 +459,10 @@ if __name__ == '__main__':
if options.email_smtp_user and not options.email_smtp_password: if options.email_smtp_user and not options.email_smtp_password:
import getpass import getpass
options.email_smtp_password = getpass.getpass('Please enter SMTP password: ')
logging.info('Initialize Email client') options.email_smtp_password = getpass.getpass("Please enter SMTP password: ")
logging.info("Initialize Email client")
email_client = EmailClient( email_client = EmailClient(
smtp_host=options.email_smtp_host, smtp_host=options.email_smtp_host,
smtp_port=options.email_smtp_port, smtp_port=options.email_smtp_port,
@ -441,20 +480,24 @@ if __name__ == '__main__':
test=dict( test=dict(
subject="Test email", subject="Test email",
text=( text=(
"Just a test email sent at {sent_date}." if not options.test_mako else "Just a test email sent at {sent_date}."
MakoTemplate("Just a test email sent at ${sent_date}.") if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
), ),
html=( html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else "<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>") if not options.test_mako
) else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
)
),
) )
) ),
) )
logging.info('Send a test email to %s', options.test_to) logging.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
logging.info('Test email sent') logging.info("Test email sent")
sys.exit(0) sys.exit(0)
logging.error('Fail to send test email') logging.error("Fail to send test email")
sys.exit(1) sys.exit(1)

File diff suppressed because it is too large Load diff

View file

@ -48,19 +48,18 @@ Return format :
import logging import logging
import re import re
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def clean_value(value): def clean_value(value):
""" Clean value as encoded string """ """Clean value as encoded string"""
if isinstance(value, int): if isinstance(value, int):
value = str(value) value = str(value)
return value return value
def get_values(dst, dst_key, src, m): def get_values(dst, dst_key, src, m):
""" Extract sources values """ """Extract sources values"""
values = [] values = []
if "other_key" in m: if "other_key" in m:
if m["other_key"] in dst: if m["other_key"] in dst:

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" MySQL client """ """ MySQL client """
import logging import logging
@ -8,15 +6,13 @@ import sys
import MySQLdb import MySQLdb
from MySQLdb._exceptions import Error from MySQLdb._exceptions import Error
from mylib.db import DB from mylib.db import DB, DBFailToConnect
from mylib.db import DBFailToConnect
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class MyDB(DB): class MyDB(DB):
""" MySQL client """ """MySQL client"""
_host = None _host = None
_user = None _user = None
@ -28,25 +24,33 @@ class MyDB(DB):
self._user = user self._user = user
self._pwd = pwd self._pwd = pwd
self._db = db self._db = db
self._charset = charset if charset else 'utf8' self._charset = charset if charset else "utf8"
super().__init__(**kwargs) super().__init__(**kwargs)
def connect(self, exit_on_error=True): def connect(self, exit_on_error=True):
""" Connect to MySQL server """ """Connect to MySQL server"""
if self._conn is None: if self._conn is None:
try: try:
self._conn = MySQLdb.connect( self._conn = MySQLdb.connect(
host=self._host, user=self._user, passwd=self._pwd, host=self._host,
db=self._db, charset=self._charset, use_unicode=True) user=self._user,
passwd=self._pwd,
db=self._db,
charset=self._charset,
use_unicode=True,
)
except Error as err: except Error as err:
log.fatal( log.fatal(
'An error occured during MySQL database connection (%s@%s:%s).', "An error occured during MySQL database connection (%s@%s:%s).",
self._user, self._host, self._db, exc_info=1 self._user,
self._host,
self._db,
exc_info=1,
) )
if exit_on_error: if exit_on_error:
sys.exit(1) sys.exit(1)
else: else:
raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err
return True return True
def doSQL(self, sql, params=None): def doSQL(self, sql, params=None):
@ -88,10 +92,7 @@ class MyDB(DB):
cursor = self._conn.cursor() cursor = self._conn.cursor()
cursor.execute(sql, params) cursor.execute(sql, params)
return [ return [
dict( {field[0]: row[idx] for idx, field in enumerate(cursor.description)}
(field[0], row[idx])
for idx, field in enumerate(cursor.description)
)
for row in cursor.fetchall() for row in cursor.fetchall()
] ]
except Error: except Error:
@ -100,14 +101,12 @@ class MyDB(DB):
@staticmethod @staticmethod
def _quote_table_name(table): def _quote_table_name(table):
""" Quote table name """ """Quote table name"""
return '`{0}`'.format( # pylint: disable=consider-using-f-string return "`{}`".format( # pylint: disable=consider-using-f-string
'`.`'.join( "`.`".join(table.split("."))
table.split('.')
)
) )
@staticmethod @staticmethod
def _quote_field_name(field): def _quote_field_name(field):
""" Quote table name """ """Quote table name"""
return f'`{field}`' return f"`{field}`"

View file

@ -1,22 +1,20 @@
# -*- coding: utf-8 -*-
""" Opening hours helpers """ """ Opening hours helpers """
import datetime import datetime
import logging
import re import re
import time import time
import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
week_days = ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi', 'samedi', 'dimanche'] week_days = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"]
date_format = '%d/%m/%Y' date_format = "%d/%m/%Y"
date_pattern = re.compile('^([0-9]{2})/([0-9]{2})/([0-9]{4})$') date_pattern = re.compile("^([0-9]{2})/([0-9]{2})/([0-9]{4})$")
time_pattern = re.compile('^([0-9]{1,2})h([0-9]{2})?$') time_pattern = re.compile("^([0-9]{1,2})h([0-9]{2})?$")
def easter_date(year): def easter_date(year):
""" Compute easter date for the specified year """ """Compute easter date for the specified year"""
a = year // 100 a = year // 100
b = year % 100 b = year % 100
c = (3 * (a + 25)) // 4 c = (3 * (a + 25)) // 4
@ -36,30 +34,30 @@ def easter_date(year):
def nonworking_french_public_days_of_the_year(year=None): def nonworking_french_public_days_of_the_year(year=None):
""" Compute dict of nonworking french public days for the specified year """ """Compute dict of nonworking french public days for the specified year"""
if year is None: if year is None:
year = datetime.date.today().year year = datetime.date.today().year
dp = easter_date(year) dp = easter_date(year)
return { return {
'1janvier': datetime.date(year, 1, 1), "1janvier": datetime.date(year, 1, 1),
'paques': dp, "paques": dp,
'lundi_paques': (dp + datetime.timedelta(1)), "lundi_paques": (dp + datetime.timedelta(1)),
'1mai': datetime.date(year, 5, 1), "1mai": datetime.date(year, 5, 1),
'8mai': datetime.date(year, 5, 8), "8mai": datetime.date(year, 5, 8),
'jeudi_ascension': (dp + datetime.timedelta(39)), "jeudi_ascension": (dp + datetime.timedelta(39)),
'pentecote': (dp + datetime.timedelta(49)), "pentecote": (dp + datetime.timedelta(49)),
'lundi_pentecote': (dp + datetime.timedelta(50)), "lundi_pentecote": (dp + datetime.timedelta(50)),
'14juillet': datetime.date(year, 7, 14), "14juillet": datetime.date(year, 7, 14),
'15aout': datetime.date(year, 8, 15), "15aout": datetime.date(year, 8, 15),
'1novembre': datetime.date(year, 11, 1), "1novembre": datetime.date(year, 11, 1),
'11novembre': datetime.date(year, 11, 11), "11novembre": datetime.date(year, 11, 11),
'noel': datetime.date(year, 12, 25), "noel": datetime.date(year, 12, 25),
'saint_etienne': datetime.date(year, 12, 26), "saint_etienne": datetime.date(year, 12, 26),
} }
def parse_exceptional_closures(values): def parse_exceptional_closures(values):
""" Parse exceptional closures values """ """Parse exceptional closures values"""
exceptional_closures = [] exceptional_closures = []
for value in values: for value in values:
days = [] days = []
@ -68,7 +66,7 @@ def parse_exceptional_closures(values):
for word in words: for word in words:
if not word: if not word:
continue continue
parts = word.split('-') parts = word.split("-")
if len(parts) == 1: if len(parts) == 1:
# ex: 31/02/2017 # ex: 31/02/2017
ptime = time.strptime(word, date_format) ptime = time.strptime(word, date_format)
@ -82,7 +80,7 @@ def parse_exceptional_closures(values):
pstart = time.strptime(parts[0], date_format) pstart = time.strptime(parts[0], date_format)
pstop = time.strptime(parts[1], date_format) pstop = time.strptime(parts[1], date_format)
if pstop <= pstart: if pstop <= pstart:
raise ValueError(f'Day {parts[1]} <= {parts[0]}') raise ValueError(f"Day {parts[1]} <= {parts[0]}")
date = datetime.date(pstart.tm_year, pstart.tm_mon, pstart.tm_mday) date = datetime.date(pstart.tm_year, pstart.tm_mon, pstart.tm_mday)
stop_date = datetime.date(pstop.tm_year, pstop.tm_mon, pstop.tm_mday) stop_date = datetime.date(pstop.tm_year, pstop.tm_mon, pstop.tm_mday)
@ -99,18 +97,18 @@ def parse_exceptional_closures(values):
hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0)) hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0))
hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0)) hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0))
if hstop <= hstart: if hstop <= hstart:
raise ValueError(f'Time {parts[1]} <= {parts[0]}') raise ValueError(f"Time {parts[1]} <= {parts[0]}")
hours_periods.append({'start': hstart, 'stop': hstop}) hours_periods.append({"start": hstart, "stop": hstop})
else: else:
raise ValueError(f'Invalid number of part in this word: "{word}"') raise ValueError(f'Invalid number of part in this word: "{word}"')
if not days: if not days:
raise ValueError(f'No days found in value "{value}"') raise ValueError(f'No days found in value "{value}"')
exceptional_closures.append({'days': days, 'hours_periods': hours_periods}) exceptional_closures.append({"days": days, "hours_periods": hours_periods})
return exceptional_closures return exceptional_closures
def parse_normal_opening_hours(values): def parse_normal_opening_hours(values):
""" Parse normal opening hours """ """Parse normal opening hours"""
normal_opening_hours = [] normal_opening_hours = []
for value in values: for value in values:
days = [] days = []
@ -119,7 +117,7 @@ def parse_normal_opening_hours(values):
for word in words: for word in words:
if not word: if not word:
continue continue
parts = word.split('-') parts = word.split("-")
if len(parts) == 1: if len(parts) == 1:
# ex: jeudi # ex: jeudi
if word not in week_days: if word not in week_days:
@ -150,40 +148,51 @@ def parse_normal_opening_hours(values):
hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0)) hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0))
hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0)) hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0))
if hstop <= hstart: if hstop <= hstart:
raise ValueError(f'Time {parts[1]} <= {parts[0]}') raise ValueError(f"Time {parts[1]} <= {parts[0]}")
hours_periods.append({'start': hstart, 'stop': hstop}) hours_periods.append({"start": hstart, "stop": hstop})
else: else:
raise ValueError(f'Invalid number of part in this word: "{word}"') raise ValueError(f'Invalid number of part in this word: "{word}"')
if not days and not hours_periods: if not days and not hours_periods:
raise ValueError(f'No days or hours period found in this value: "{value}"') raise ValueError(f'No days or hours period found in this value: "{value}"')
normal_opening_hours.append({'days': days, 'hours_periods': hours_periods}) normal_opening_hours.append({"days": days, "hours_periods": hours_periods})
return normal_opening_hours return normal_opening_hours
def is_closed( def is_closed(
normal_opening_hours_values=None, exceptional_closures_values=None, normal_opening_hours_values=None,
nonworking_public_holidays_values=None, exceptional_closure_on_nonworking_public_days=False, exceptional_closures_values=None,
when=None, on_error='raise' nonworking_public_holidays_values=None,
exceptional_closure_on_nonworking_public_days=False,
when=None,
on_error="raise",
): ):
""" Check if closed """ """Check if closed"""
if not when: if not when:
when = datetime.datetime.now() when = datetime.datetime.now()
when_date = when.date() when_date = when.date()
when_time = when.time() when_time = when.time()
when_weekday = week_days[when.timetuple().tm_wday] when_weekday = week_days[when.timetuple().tm_wday]
on_error_result = None on_error_result = None
if on_error == 'closed': if on_error == "closed":
on_error_result = { on_error_result = {
'closed': True, 'exceptional_closure': False, "closed": True,
'exceptional_closure_all_day': False} "exceptional_closure": False,
elif on_error == 'opened': "exceptional_closure_all_day": False,
}
elif on_error == "opened":
on_error_result = { on_error_result = {
'closed': False, 'exceptional_closure': False, "closed": False,
'exceptional_closure_all_day': False} "exceptional_closure": False,
"exceptional_closure_all_day": False,
}
log.debug( log.debug(
"When = %s => date = %s / time = %s / week day = %s", "When = %s => date = %s / time = %s / week day = %s",
when, when_date, when_time, when_weekday) when,
when_date,
when_time,
when_weekday,
)
if nonworking_public_holidays_values: if nonworking_public_holidays_values:
log.debug("Nonworking public holidays: %s", nonworking_public_holidays_values) log.debug("Nonworking public holidays: %s", nonworking_public_holidays_values)
nonworking_days = nonworking_french_public_days_of_the_year() nonworking_days = nonworking_french_public_days_of_the_year()
@ -191,65 +200,69 @@ def is_closed(
if day in nonworking_days and when_date == nonworking_days[day]: if day in nonworking_days and when_date == nonworking_days[day]:
log.debug("Non working day: %s", day) log.debug("Non working day: %s", day)
return { return {
'closed': True, "closed": True,
'exceptional_closure': exceptional_closure_on_nonworking_public_days, "exceptional_closure": exceptional_closure_on_nonworking_public_days,
'exceptional_closure_all_day': exceptional_closure_on_nonworking_public_days "exceptional_closure_all_day": exceptional_closure_on_nonworking_public_days,
} }
if exceptional_closures_values: if exceptional_closures_values:
try: try:
exceptional_closures = parse_exceptional_closures(exceptional_closures_values) exceptional_closures = parse_exceptional_closures(exceptional_closures_values)
log.debug('Exceptional closures: %s', exceptional_closures) log.debug("Exceptional closures: %s", exceptional_closures)
except ValueError as e: except ValueError as e:
log.error("Fail to parse exceptional closures, consider as closed", exc_info=True) log.error("Fail to parse exceptional closures, consider as closed", exc_info=True)
if on_error_result is None: if on_error_result is None:
raise e from e raise e from e
return on_error_result return on_error_result
for cl in exceptional_closures: for cl in exceptional_closures:
if when_date not in cl['days']: if when_date not in cl["days"]:
log.debug("when_date (%s) no in days (%s)", when_date, cl['days']) log.debug("when_date (%s) no in days (%s)", when_date, cl["days"])
continue continue
if not cl['hours_periods']: if not cl["hours_periods"]:
# All day exceptional closure # All day exceptional closure
return { return {
'closed': True, 'exceptional_closure': True, "closed": True,
'exceptional_closure_all_day': True} "exceptional_closure": True,
for hp in cl['hours_periods']: "exceptional_closure_all_day": True,
if hp['start'] <= when_time <= hp['stop']: }
for hp in cl["hours_periods"]:
if hp["start"] <= when_time <= hp["stop"]:
return { return {
'closed': True, 'exceptional_closure': True, "closed": True,
'exceptional_closure_all_day': False} "exceptional_closure": True,
"exceptional_closure_all_day": False,
}
if normal_opening_hours_values: if normal_opening_hours_values:
try: try:
normal_opening_hours = parse_normal_opening_hours(normal_opening_hours_values) normal_opening_hours = parse_normal_opening_hours(normal_opening_hours_values)
log.debug('Normal opening hours: %s', normal_opening_hours) log.debug("Normal opening hours: %s", normal_opening_hours)
except ValueError as e: # pylint: disable=broad-except except ValueError as e: # pylint: disable=broad-except
log.error("Fail to parse normal opening hours, consider as closed", exc_info=True) log.error("Fail to parse normal opening hours, consider as closed", exc_info=True)
if on_error_result is None: if on_error_result is None:
raise e from e raise e from e
return on_error_result return on_error_result
for oh in normal_opening_hours: for oh in normal_opening_hours:
if oh['days'] and when_weekday not in oh['days']: if oh["days"] and when_weekday not in oh["days"]:
log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh['days']) log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh["days"])
continue continue
if not oh['hours_periods']: if not oh["hours_periods"]:
# All day opened # All day opened
return { return {
'closed': False, 'exceptional_closure': False, "closed": False,
'exceptional_closure_all_day': False} "exceptional_closure": False,
for hp in oh['hours_periods']: "exceptional_closure_all_day": False,
if hp['start'] <= when_time <= hp['stop']: }
for hp in oh["hours_periods"]:
if hp["start"] <= when_time <= hp["stop"]:
return { return {
'closed': False, 'exceptional_closure': False, "closed": False,
'exceptional_closure_all_day': False} "exceptional_closure": False,
"exceptional_closure_all_day": False,
}
log.debug("Not in normal opening hours => closed") log.debug("Not in normal opening hours => closed")
return { return {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True, 'exceptional_closure': False,
'exceptional_closure_all_day': False}
# Not a nonworking day, not during exceptional closure and no normal opening # Not a nonworking day, not during exceptional closure and no normal opening
# hours defined => Opened # hours defined => Opened
return { return {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': False, 'exceptional_closure': False,
'exceptional_closure_all_day': False}

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" Oracle client """ """ Oracle client """
import logging import logging
@ -7,14 +5,13 @@ import sys
import cx_Oracle import cx_Oracle
from mylib.db import DB from mylib.db import DB, DBFailToConnect
from mylib.db import DBFailToConnect
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class OracleDB(DB): class OracleDB(DB):
""" Oracle client """ """Oracle client"""
_dsn = None _dsn = None
_user = None _user = None
@ -27,24 +24,22 @@ class OracleDB(DB):
super().__init__(**kwargs) super().__init__(**kwargs)
def connect(self, exit_on_error=True): def connect(self, exit_on_error=True):
""" Connect to Oracle server """ """Connect to Oracle server"""
if self._conn is None: if self._conn is None:
log.info('Connect on Oracle server with DSN %s as %s', self._dsn, self._user) log.info("Connect on Oracle server with DSN %s as %s", self._dsn, self._user)
try: try:
self._conn = cx_Oracle.connect( self._conn = cx_Oracle.connect(user=self._user, password=self._pwd, dsn=self._dsn)
user=self._user,
password=self._pwd,
dsn=self._dsn
)
except cx_Oracle.Error as err: except cx_Oracle.Error as err:
log.fatal( log.fatal(
'An error occured during Oracle database connection (%s@%s).', "An error occured during Oracle database connection (%s@%s).",
self._user, self._dsn, exc_info=1 self._user,
self._dsn,
exc_info=1,
) )
if exit_on_error: if exit_on_error:
sys.exit(1) sys.exit(1)
else: else:
raise DBFailToConnect(f'{self._user}@{self._dsn}') from err raise DBFailToConnect(f"{self._user}@{self._dsn}") from err
return True return True
def doSQL(self, sql, params=None): def doSQL(self, sql, params=None):
@ -107,5 +102,5 @@ class OracleDB(DB):
@staticmethod @staticmethod
def format_param(param): def format_param(param):
""" Format SQL query parameter for prepared query """ """Format SQL query parameter for prepared query"""
return f':{param}' return f":{param}"

View file

@ -1,10 +1,8 @@
# coding: utf8
""" Progress bar """ """ Progress bar """
import logging import logging
import progressbar
import progressbar
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -25,15 +23,15 @@ class Pbar: # pylint: disable=useless-object-inheritance
self.__count = 0 self.__count = 0
self.__pbar = progressbar.ProgressBar( self.__pbar = progressbar.ProgressBar(
widgets=[ widgets=[
name + ': ', name + ": ",
progressbar.Percentage(), progressbar.Percentage(),
' ', " ",
progressbar.Bar(), progressbar.Bar(),
' ', " ",
progressbar.SimpleProgress(), progressbar.SimpleProgress(),
progressbar.ETA() progressbar.ETA(),
], ],
maxval=maxval maxval=maxval,
).start() ).start()
else: else:
log.info(name) log.info(name)
@ -49,6 +47,6 @@ class Pbar: # pylint: disable=useless-object-inheritance
self.__pbar.update(self.__count) self.__pbar.update(self.__count)
def finish(self): def finish(self):
""" Finish the progress bar """ """Finish the progress bar"""
if self.__pbar: if self.__pbar:
self.__pbar.finish() self.__pbar.finish()

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" PostgreSQL client """ """ PostgreSQL client """
import datetime import datetime
@ -14,15 +12,15 @@ log = logging.getLogger(__name__)
class PgDB(DB): class PgDB(DB):
""" PostgreSQL client """ """PostgreSQL client"""
_host = None _host = None
_user = None _user = None
_pwd = None _pwd = None
_db = None _db = None
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
datetime_format = '%Y-%m-%d %H:%M:%S' datetime_format = "%Y-%m-%d %H:%M:%S"
def __init__(self, host, user, pwd, db, **kwargs): def __init__(self, host, user, pwd, db, **kwargs):
self._host = host self._host = host
@ -32,37 +30,40 @@ class PgDB(DB):
super().__init__(**kwargs) super().__init__(**kwargs)
def connect(self, exit_on_error=True): def connect(self, exit_on_error=True):
""" Connect to PostgreSQL server """ """Connect to PostgreSQL server"""
if self._conn is None: if self._conn is None:
try: try:
log.info( log.info(
'Connect on PostgreSQL server %s as %s on database %s', "Connect on PostgreSQL server %s as %s on database %s",
self._host, self._user, self._db) self._host,
self._user,
self._db,
)
self._conn = psycopg2.connect( self._conn = psycopg2.connect(
dbname=self._db, dbname=self._db, user=self._user, host=self._host, password=self._pwd
user=self._user,
host=self._host,
password=self._pwd
) )
except psycopg2.Error as err: except psycopg2.Error as err:
log.fatal( log.fatal(
'An error occured during Postgresql database connection (%s@%s, database=%s).', "An error occured during Postgresql database connection (%s@%s, database=%s).",
self._user, self._host, self._db, exc_info=1 self._user,
self._host,
self._db,
exc_info=1,
) )
if exit_on_error: if exit_on_error:
sys.exit(1) sys.exit(1)
else: else:
raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err
return True return True
def close(self): def close(self):
""" Close connection with PostgreSQL server (if opened) """ """Close connection with PostgreSQL server (if opened)"""
if self._conn: if self._conn:
self._conn.close() self._conn.close()
self._conn = None self._conn = None
def setEncoding(self, enc): def setEncoding(self, enc):
""" Set connection encoding """ """Set connection encoding"""
if self._conn: if self._conn:
try: try:
self._conn.set_client_encoding(enc) self._conn.set_client_encoding(enc)
@ -70,7 +71,8 @@ class PgDB(DB):
except psycopg2.Error: except psycopg2.Error:
log.error( log.error(
'An error occured setting Postgresql database connection encoding to "%s"', 'An error occured setting Postgresql database connection encoding to "%s"',
enc, exc_info=1 enc,
exc_info=1,
) )
return False return False
@ -124,10 +126,7 @@ class PgDB(DB):
@staticmethod @staticmethod
def _map_row_fields_by_index(fields, row): def _map_row_fields_by_index(fields, row):
return dict( return {field: row[idx] for idx, field in enumerate(fields)}
(field, row[idx])
for idx, field in enumerate(fields)
)
# #
# Depreated helpers # Depreated helpers
@ -135,9 +134,9 @@ class PgDB(DB):
@classmethod @classmethod
def _quote_value(cls, value): def _quote_value(cls, value):
""" Quote a value for SQL query """ """Quote a value for SQL query"""
if value is None: if value is None:
return 'NULL' return "NULL"
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
return str(value) return str(value)
@ -148,26 +147,26 @@ class PgDB(DB):
value = cls._format_date(value) value = cls._format_date(value)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
return "'{0}'".format(value.replace("'", "''")) return "'{}'".format(value.replace("'", "''"))
@classmethod @classmethod
def _format_datetime(cls, value): def _format_datetime(cls, value):
""" Format datetime object as string """ """Format datetime object as string"""
assert isinstance(value, datetime.datetime) assert isinstance(value, datetime.datetime)
return value.strftime(cls.datetime_format) return value.strftime(cls.datetime_format)
@classmethod @classmethod
def _format_date(cls, value): def _format_date(cls, value):
""" Format date object as string """ """Format date object as string"""
assert isinstance(value, (datetime.date, datetime.datetime)) assert isinstance(value, (datetime.date, datetime.datetime))
return value.strftime(cls.date_format) return value.strftime(cls.date_format)
@classmethod @classmethod
def time2datetime(cls, time): def time2datetime(cls, time):
""" Convert timestamp to datetime string """ """Convert timestamp to datetime string"""
return cls._format_datetime(datetime.datetime.fromtimestamp(int(time))) return cls._format_datetime(datetime.datetime.fromtimestamp(int(time)))
@classmethod @classmethod
def time2date(cls, time): def time2date(cls, time):
""" Convert timestamp to date string """ """Convert timestamp to date string"""
return cls._format_date(datetime.date.fromtimestamp(int(time))) return cls._format_date(datetime.date.fromtimestamp(int(time)))

View file

@ -1,28 +1,24 @@
# coding: utf8
""" Report """ """ Report """
import atexit import atexit
import logging import logging
from mylib.config import ConfigurableObject from mylib.config import ConfigurableObject, StringOption
from mylib.config import StringOption
from mylib.email import EmailClient from mylib.email import EmailClient
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
""" Logging report """ """Logging report"""
_config_name = 'report' _config_name = "report"
_config_comment = 'Email report' _config_comment = "Email report"
_defaults = { _defaults = {
'recipient': None, "recipient": None,
'subject': 'Report', "subject": "Report",
'loglevel': 'WARNING', "loglevel": "WARNING",
'logformat': '%(asctime)s - %(levelname)s - %(message)s', "logformat": "%(asctime)s - %(levelname)s - %(message)s",
} }
content = [] content = []
@ -40,20 +36,28 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
self.initialize() self.initialize()
def configure(self, **kwargs): # pylint: disable=arguments-differ def configure(self, **kwargs): # pylint: disable=arguments-differ
""" Configure options on registered mylib.Config object """ """Configure options on registered mylib.Config object"""
section = super().configure(**kwargs) section = super().configure(**kwargs)
section.add_option(StringOption, "recipient", comment="Report recipient email address")
section.add_option( section.add_option(
StringOption, 'recipient', comment='Report recipient email address') StringOption,
"subject",
default=self._defaults["subject"],
comment="Report email subject",
)
section.add_option( section.add_option(
StringOption, 'subject', default=self._defaults['subject'], StringOption,
comment='Report email subject') "loglevel",
default=self._defaults["loglevel"],
comment='Report log level (as accept by python logging, for instance "INFO")',
)
section.add_option( section.add_option(
StringOption, 'loglevel', default=self._defaults['loglevel'], StringOption,
comment='Report log level (as accept by python logging, for instance "INFO")') "logformat",
section.add_option( default=self._defaults["logformat"],
StringOption, 'logformat', default=self._defaults['logformat'], comment='Report log level (as accept by python logging, for instance "INFO")',
comment='Report log level (as accept by python logging, for instance "INFO")') )
if not self.email_client: if not self.email_client:
self.email_client = EmailClient(config=self._config) self.email_client = EmailClient(config=self._config)
@ -62,66 +66,70 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance
return section return section
def initialize(self, loaded_config=None): def initialize(self, loaded_config=None):
""" Configuration initialized hook """ """Configuration initialized hook"""
super().initialize(loaded_config=loaded_config) super().initialize(loaded_config=loaded_config)
self.handler = logging.StreamHandler(self) self.handler = logging.StreamHandler(self)
loglevel = self._get_option('loglevel').upper() loglevel = self._get_option("loglevel").upper()
assert hasattr(logging, loglevel), ( assert hasattr(logging, loglevel), f"Invalid report loglevel {loglevel}"
f'Invalid report loglevel {loglevel}')
self.handler.setLevel(getattr(logging, loglevel)) self.handler.setLevel(getattr(logging, loglevel))
self.formatter = logging.Formatter(self._get_option('logformat')) self.formatter = logging.Formatter(self._get_option("logformat"))
self.handler.setFormatter(self.formatter) self.handler.setFormatter(self.formatter)
def get_handler(self): def get_handler(self):
""" Retreive logging handler """ """Retreive logging handler"""
return self.handler return self.handler
def write(self, msg): def write(self, msg):
""" Write a message """ """Write a message"""
self.content.append(msg) self.content.append(msg)
def get_content(self): def get_content(self):
""" Read the report content """ """Read the report content"""
return "".join(self.content) return "".join(self.content)
def add_attachment_file(self, filepath): def add_attachment_file(self, filepath):
""" Add attachment file """ """Add attachment file"""
self._attachment_files.append(filepath) self._attachment_files.append(filepath)
def add_attachment_payload(self, payload): def add_attachment_payload(self, payload):
""" Add attachment payload """ """Add attachment payload"""
self._attachment_payloads.append(payload) self._attachment_payloads.append(payload)
def send(self, subject=None, rcpt_to=None, email_client=None, just_try=False): def send(self, subject=None, rcpt_to=None, email_client=None, just_try=False):
""" Send report using an EmailClient """ """Send report using an EmailClient"""
if rcpt_to is None: if rcpt_to is None:
rcpt_to = self._get_option('recipient') rcpt_to = self._get_option("recipient")
if not rcpt_to: if not rcpt_to:
log.debug('No report recipient, do not send report') log.debug("No report recipient, do not send report")
return True return True
if subject is None: if subject is None:
subject = self._get_option('subject') subject = self._get_option("subject")
assert subject, "You must provide report subject using Report.__init__ or Report.send" assert subject, "You must provide report subject using Report.__init__ or Report.send"
if email_client is None: if email_client is None:
email_client = self.email_client email_client = self.email_client
assert email_client, ( assert email_client, (
"You must provide an email client __init__(), send() or send_at_exit() methods argument email_client") "You must provide an email client __init__(), send() or send_at_exit() methods argument"
" email_client"
)
content = self.get_content() content = self.get_content()
if not content: if not content:
log.debug('Report is empty, do not send it') log.debug("Report is empty, do not send it")
return True return True
msg = email_client.forge_message( msg = email_client.forge_message(
rcpt_to, subject=subject, text_body=content, rcpt_to,
subject=subject,
text_body=content,
attachment_files=self._attachment_files, attachment_files=self._attachment_files,
attachment_payloads=self._attachment_payloads) attachment_payloads=self._attachment_payloads,
)
if email_client.send(rcpt_to, msg=msg, just_try=just_try): if email_client.send(rcpt_to, msg=msg, just_try=just_try):
log.debug('Report sent to %s', rcpt_to) log.debug("Report sent to %s", rcpt_to)
return True return True
log.error('Fail to send report to %s', rcpt_to) log.error("Fail to send report to %s", rcpt_to)
return False return False
def send_at_exit(self, **kwargs): def send_at_exit(self, **kwargs):
""" Send report at exit """ """Send report at exit"""
atexit.register(self.send, **kwargs) atexit.register(self.send, **kwargs)

View file

@ -1,22 +1,18 @@
# -*- coding: utf-8 -*-
""" Test Email client """ """ Test Email client """
import datetime import datetime
import getpass
import logging import logging
import sys import sys
import getpass
from mako.template import Template as MakoTemplate from mako.template import Template as MakoTemplate
from mylib.scripts.helpers import get_opts_parser, add_email_opts from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging
from mylib.scripts.helpers import init_logging, init_email_client
log = logging.getLogger("mylib.scripts.email_test")
log = logging.getLogger('mylib.scripts.email_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
""" Script main """ """Script main"""
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
@ -24,10 +20,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True) parser = get_opts_parser(just_try=True)
add_email_opts(parser) add_email_opts(parser)
test_opts = parser.add_argument_group('Test email options') test_opts = parser.add_argument_group("Test email options")
test_opts.add_argument( test_opts.add_argument(
'-t', '--to', "-t",
"--to",
action="store", action="store",
type=str, type=str,
dest="test_to", dest="test_to",
@ -35,7 +32,8 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
) )
test_opts.add_argument( test_opts.add_argument(
'-m', '--mako', "-m",
"--mako",
action="store_true", action="store_true",
dest="test_mako", dest="test_mako",
help="Test mako templating", help="Test mako templating",
@ -44,14 +42,14 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args() options = parser.parse_args()
if not options.test_to: if not options.test_to:
parser.error('You must specify test email recipient using -t/--to parameter') parser.error("You must specify test email recipient using -t/--to parameter")
sys.exit(1) sys.exit(1)
# Initialize logs # Initialize logs
init_logging(options, 'Test EmailClient') init_logging(options, "Test EmailClient")
if options.email_smtp_user and not options.email_smtp_password: if options.email_smtp_user and not options.email_smtp_password:
options.email_smtp_password = getpass.getpass('Please enter SMTP password: ') options.email_smtp_password = getpass.getpass("Please enter SMTP password: ")
email_client = init_email_client( email_client = init_email_client(
options, options,
@ -59,20 +57,24 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
test=dict( test=dict(
subject="Test email", subject="Test email",
text=( text=(
"Just a test email sent at {sent_date}." if not options.test_mako else "Just a test email sent at {sent_date}."
MakoTemplate("Just a test email sent at ${sent_date}.") if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
), ),
html=( html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else "<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>") if not options.test_mako
) else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
)
),
) )
) ),
) )
log.info('Send a test email to %s', options.test_to) log.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
log.info('Test email sent') log.info("Test email sent")
sys.exit(0) sys.exit(0)
log.error('Fail to send test email') log.error("Fail to send test email")
sys.exit(1) sys.exit(1)

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" Test Email client using mylib.config.Config for configuration """ """ Test Email client using mylib.config.Config for configuration """
import datetime import datetime
import logging import logging
@ -10,16 +8,15 @@ from mako.template import Template as MakoTemplate
from mylib.config import Config from mylib.config import Config
from mylib.email import EmailClient from mylib.email import EmailClient
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
""" Script main """ """Script main"""
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
config = Config(__doc__, __name__.replace('.', '_')) config = Config(__doc__, __name__.replace(".", "_"))
email_client = EmailClient(config=config) email_client = EmailClient(config=config)
email_client.configure() email_client.configure()
@ -27,10 +24,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
# Options parser # Options parser
parser = config.get_arguments_parser(description=__doc__) parser = config.get_arguments_parser(description=__doc__)
test_opts = parser.add_argument_group('Test email options') test_opts = parser.add_argument_group("Test email options")
test_opts.add_argument( test_opts.add_argument(
'-t', '--to', "-t",
"--to",
action="store", action="store",
type=str, type=str,
dest="test_to", dest="test_to",
@ -38,7 +36,8 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
) )
test_opts.add_argument( test_opts.add_argument(
'-m', '--mako', "-m",
"--mako",
action="store_true", action="store_true",
dest="test_mako", dest="test_mako",
help="Test mako templating", help="Test mako templating",
@ -47,26 +46,30 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = config.parse_arguments_options() options = config.parse_arguments_options()
if not options.test_to: if not options.test_to:
parser.error('You must specify test email recipient using -t/--to parameter') parser.error("You must specify test email recipient using -t/--to parameter")
sys.exit(1) sys.exit(1)
email_client.templates = dict( email_client.templates = dict(
test=dict( test=dict(
subject="Test email", subject="Test email",
text=( text=(
"Just a test email sent at {sent_date}." if not options.test_mako else "Just a test email sent at {sent_date}."
MakoTemplate("Just a test email sent at ${sent_date}.") if not options.test_mako
else MakoTemplate("Just a test email sent at ${sent_date}.")
), ),
html=( html=(
"<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>" if not options.test_mako else "<strong>Just a test email.</strong> <small>(sent at {sent_date})</small>"
MakoTemplate("<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>") if not options.test_mako
) else MakoTemplate(
"<strong>Just a test email.</strong> <small>(sent at ${sent_date})</small>"
)
),
) )
) )
logging.info('Send a test email to %s', options.test_to) logging.info("Send a test email to %s", options.test_to)
if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()):
logging.info('Test email sent') logging.info("Test email sent")
sys.exit(0) sys.exit(0)
logging.error('Fail to send test email') logging.error("Fail to send test email")
sys.exit(1) sys.exit(1)

View file

@ -1,20 +1,18 @@
# coding: utf8
""" Scripts helpers """ """ Scripts helpers """
import argparse import argparse
import getpass import getpass
import logging import logging
import os.path
import socket import socket
import sys import sys
import os.path
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def init_logging(options, name, report=None): def init_logging(options, name, report=None):
""" Initialize logging from calling script options """ """Initialize logging from calling script options"""
logformat = f'%(asctime)s - {name} - %(levelname)s - %(message)s' logformat = f"%(asctime)s - {name} - %(levelname)s - %(message)s"
if options.debug: if options.debug:
loglevel = logging.DEBUG loglevel = logging.DEBUG
elif options.verbose: elif options.verbose:
@ -33,193 +31,201 @@ def init_logging(options, name, report=None):
def get_default_opt_value(config, default_config, key): def get_default_opt_value(config, default_config, key):
""" Retreive default option value from config or default config dictionaries """ """Retreive default option value from config or default config dictionaries"""
if config and key in config: if config and key in config:
return config[key] return config[key]
return default_config.get(key) return default_config.get(key)
def get_opts_parser(desc=None, just_try=False, just_one=False, progress=False, config=None): def get_opts_parser(desc=None, just_try=False, just_one=False, progress=False, config=None):
""" Retrieve options parser """ """Retrieve options parser"""
default_config = dict(logfile=None) default_config = dict(logfile=None)
parser = argparse.ArgumentParser(description=desc) parser = argparse.ArgumentParser(description=desc)
parser.add_argument( parser.add_argument(
'-v', '--verbose', "-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode"
action="store_true",
dest="verbose",
help="Enable verbose mode"
) )
parser.add_argument( parser.add_argument(
'-d', '--debug', "-d", "--debug", action="store_true", dest="debug", help="Enable debug mode"
action="store_true",
dest="debug",
help="Enable debug mode"
) )
parser.add_argument( parser.add_argument(
'-l', '--log-file', "-l",
"--log-file",
action="store", action="store",
type=str, type=str,
dest="logfile", dest="logfile",
help=( help=f'Log file path (default: {get_default_opt_value(config, default_config, "logfile")})',
'Log file path (default: ' default=get_default_opt_value(config, default_config, "logfile"),
f'{get_default_opt_value(config, default_config, "logfile")})'),
default=get_default_opt_value(config, default_config, 'logfile')
) )
parser.add_argument( parser.add_argument(
'-C', '--console', "-C",
"--console",
action="store_true", action="store_true",
dest="console", dest="console",
help="Always log on console (even if log file is configured)" help="Always log on console (even if log file is configured)",
) )
if just_try: if just_try:
parser.add_argument( parser.add_argument(
'-j', '--just-try', "-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode"
action="store_true",
dest="just_try",
help="Enable just-try mode"
) )
if just_one: if just_one:
parser.add_argument( parser.add_argument(
'-J', '--just-one', "-J", "--just-one", action="store_true", dest="just_one", help="Enable just-one mode"
action="store_true",
dest="just_one",
help="Enable just-one mode"
) )
if progress: if progress:
parser.add_argument( parser.add_argument(
'-p', '--progress', "-p", "--progress", action="store_true", dest="progress", help="Enable progress bar"
action="store_true",
dest="progress",
help="Enable progress bar"
) )
return parser return parser
def add_email_opts(parser, config=None): def add_email_opts(parser, config=None):
""" Add email options """ """Add email options"""
email_opts = parser.add_argument_group('Email options') email_opts = parser.add_argument_group("Email options")
default_config = dict( default_config = dict(
smtp_host="127.0.0.1", smtp_port=25, smtp_ssl=False, smtp_tls=False, smtp_user=None, smtp_host="127.0.0.1",
smtp_password=None, smtp_debug=False, email_encoding=sys.getdefaultencoding(), smtp_port=25,
sender_name=getpass.getuser(), sender_email=f'{getpass.getuser()}@{socket.gethostname()}', smtp_ssl=False,
catch_all=False smtp_tls=False,
smtp_user=None,
smtp_password=None,
smtp_debug=False,
email_encoding=sys.getdefaultencoding(),
sender_name=getpass.getuser(),
sender_email=f"{getpass.getuser()}@{socket.gethostname()}",
catch_all=False,
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-host', "--smtp-host",
action="store", action="store",
type=str, type=str,
dest="email_smtp_host", dest="email_smtp_host",
help=( help=f'SMTP host (default: {get_default_opt_value(config, default_config, "smtp_host")})',
'SMTP host (default: ' default=get_default_opt_value(config, default_config, "smtp_host"),
f'{get_default_opt_value(config, default_config, "smtp_host")})'),
default=get_default_opt_value(config, default_config, 'smtp_host')
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-port', "--smtp-port",
action="store", action="store",
type=int, type=int,
dest="email_smtp_port", dest="email_smtp_port",
help=f'SMTP port (default: {get_default_opt_value(config, default_config, "smtp_port")})', help=f'SMTP port (default: {get_default_opt_value(config, default_config, "smtp_port")})',
default=get_default_opt_value(config, default_config, 'smtp_port') default=get_default_opt_value(config, default_config, "smtp_port"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-ssl', "--smtp-ssl",
action="store_true", action="store_true",
dest="email_smtp_ssl", dest="email_smtp_ssl",
help=f'Use SSL (default: {get_default_opt_value(config, default_config, "smtp_ssl")})', help=f'Use SSL (default: {get_default_opt_value(config, default_config, "smtp_ssl")})',
default=get_default_opt_value(config, default_config, 'smtp_ssl') default=get_default_opt_value(config, default_config, "smtp_ssl"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-tls', "--smtp-tls",
action="store_true", action="store_true",
dest="email_smtp_tls", dest="email_smtp_tls",
help=f'Use TLS (default: {get_default_opt_value(config, default_config, "smtp_tls")})', help=f'Use TLS (default: {get_default_opt_value(config, default_config, "smtp_tls")})',
default=get_default_opt_value(config, default_config, 'smtp_tls') default=get_default_opt_value(config, default_config, "smtp_tls"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-user', "--smtp-user",
action="store", action="store",
type=str, type=str,
dest="email_smtp_user", dest="email_smtp_user",
help=f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})', help=(
default=get_default_opt_value(config, default_config, 'smtp_user') f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})'
),
default=get_default_opt_value(config, default_config, "smtp_user"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-password', "--smtp-password",
action="store", action="store",
type=str, type=str,
dest="email_smtp_password", dest="email_smtp_password",
help=f'SMTP password (default: {get_default_opt_value(config, default_config, "smtp_password")})', help=(
default=get_default_opt_value(config, default_config, 'smtp_password') "SMTP password (default:"
f' {get_default_opt_value(config, default_config, "smtp_password")})'
),
default=get_default_opt_value(config, default_config, "smtp_password"),
) )
email_opts.add_argument( email_opts.add_argument(
'--smtp-debug', "--smtp-debug",
action="store_true", action="store_true",
dest="email_smtp_debug", dest="email_smtp_debug",
help=f'Debug SMTP connection (default: {get_default_opt_value(config, default_config, "smtp_debug")})', help=(
default=get_default_opt_value(config, default_config, 'smtp_debug') "Debug SMTP connection (default:"
f' {get_default_opt_value(config, default_config, "smtp_debug")})'
),
default=get_default_opt_value(config, default_config, "smtp_debug"),
) )
email_opts.add_argument( email_opts.add_argument(
'--email-encoding', "--email-encoding",
action="store", action="store",
type=str, type=str,
dest="email_encoding", dest="email_encoding",
help=f'SMTP encoding (default: {get_default_opt_value(config, default_config, "email_encoding")})', help=(
default=get_default_opt_value(config, default_config, 'email_encoding') "SMTP encoding (default:"
f' {get_default_opt_value(config, default_config, "email_encoding")})'
),
default=get_default_opt_value(config, default_config, "email_encoding"),
) )
email_opts.add_argument( email_opts.add_argument(
'--sender-name', "--sender-name",
action="store", action="store",
type=str, type=str,
dest="email_sender_name", dest="email_sender_name",
help=f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})', help=(
default=get_default_opt_value(config, default_config, 'sender_name') f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})'
),
default=get_default_opt_value(config, default_config, "sender_name"),
) )
email_opts.add_argument( email_opts.add_argument(
'--sender-email', "--sender-email",
action="store", action="store",
type=str, type=str,
dest="email_sender_email", dest="email_sender_email",
help=f'Sender email (default: {get_default_opt_value(config, default_config, "sender_email")})', help=(
default=get_default_opt_value(config, default_config, 'sender_email') "Sender email (default:"
f' {get_default_opt_value(config, default_config, "sender_email")})'
),
default=get_default_opt_value(config, default_config, "sender_email"),
) )
email_opts.add_argument( email_opts.add_argument(
'--catch-all', "--catch-all",
action="store", action="store",
type=str, type=str,
dest="email_catch_all", dest="email_catch_all",
help=( help=(
'Catch all sent email: specify catch recipient email address ' "Catch all sent email: specify catch recipient email address "
f'(default: {get_default_opt_value(config, default_config, "catch_all")})'), f'(default: {get_default_opt_value(config, default_config, "catch_all")})'
default=get_default_opt_value(config, default_config, 'catch_all') ),
default=get_default_opt_value(config, default_config, "catch_all"),
) )
def init_email_client(options, **kwargs): def init_email_client(options, **kwargs):
""" Initialize email client from calling script options """ """Initialize email client from calling script options"""
from mylib.email import EmailClient # pylint: disable=import-outside-toplevel from mylib.email import EmailClient # pylint: disable=import-outside-toplevel
log.info('Initialize Email client')
log.info("Initialize Email client")
return EmailClient( return EmailClient(
smtp_host=options.email_smtp_host, smtp_host=options.email_smtp_host,
smtp_port=options.email_smtp_port, smtp_port=options.email_smtp_port,
@ -231,64 +237,62 @@ def init_email_client(options, **kwargs):
sender_name=options.email_sender_name, sender_name=options.email_sender_name,
sender_email=options.email_sender_email, sender_email=options.email_sender_email,
catch_all_addr=options.email_catch_all, catch_all_addr=options.email_catch_all,
just_try=options.just_try if hasattr(options, 'just_try') else False, just_try=options.just_try if hasattr(options, "just_try") else False,
encoding=options.email_encoding, encoding=options.email_encoding,
**kwargs **kwargs,
) )
def add_sftp_opts(parser): def add_sftp_opts(parser):
""" Add SFTP options to argpase.ArgumentParser """ """Add SFTP options to argpase.ArgumentParser"""
sftp_opts = parser.add_argument_group("SFTP options") sftp_opts = parser.add_argument_group("SFTP options")
sftp_opts.add_argument( sftp_opts.add_argument(
'-H', '--sftp-host', "-H",
"--sftp-host",
action="store", action="store",
type=str, type=str,
dest="sftp_host", dest="sftp_host",
help="SFTP Host (default: localhost)", help="SFTP Host (default: localhost)",
default='localhost' default="localhost",
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'--sftp-port', "--sftp-port",
action="store", action="store",
type=int, type=int,
dest="sftp_port", dest="sftp_port",
help="SFTP Port (default: 22)", help="SFTP Port (default: 22)",
default=22 default=22,
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'-u', '--sftp-user', "-u", "--sftp-user", action="store", type=str, dest="sftp_user", help="SFTP User"
action="store",
type=str,
dest="sftp_user",
help="SFTP User"
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'-P', '--sftp-password', "-P",
"--sftp-password",
action="store", action="store",
type=str, type=str,
dest="sftp_password", dest="sftp_password",
help="SFTP Password" help="SFTP Password",
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'--sftp-known-hosts', "--sftp-known-hosts",
action="store", action="store",
type=str, type=str,
dest="sftp_known_hosts", dest="sftp_known_hosts",
help="SFTP known_hosts file path (default: ~/.ssh/known_hosts)", help="SFTP known_hosts file path (default: ~/.ssh/known_hosts)",
default=os.path.expanduser('~/.ssh/known_hosts') default=os.path.expanduser("~/.ssh/known_hosts"),
) )
sftp_opts.add_argument( sftp_opts.add_argument(
'--sftp-auto-add-unknown-host-key', "--sftp-auto-add-unknown-host-key",
action="store_true", action="store_true",
dest="sftp_auto_add_unknown_host_key", dest="sftp_auto_add_unknown_host_key",
help="Auto-add unknown SSH host key" help="Auto-add unknown SSH host key",
) )
return sftp_opts return sftp_opts

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
""" Test LDAP """ """ Test LDAP """
import datetime import datetime
import logging import logging
@ -8,16 +6,14 @@ import sys
import dateutil.tz import dateutil.tz
import pytz import pytz
from mylib.ldap import format_datetime, format_date, parse_datetime, parse_date from mylib.ldap import format_date, format_datetime, parse_date, parse_datetime
from mylib.scripts.helpers import get_opts_parser from mylib.scripts.helpers import get_opts_parser, init_logging
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.ldap_test")
log = logging.getLogger('mylib.scripts.ldap_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
""" Script main """ """Script main"""
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
@ -26,52 +22,121 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args() options = parser.parse_args()
# Initialize logs # Initialize logs
init_logging(options, 'Test LDAP helpers') init_logging(options, "Test LDAP helpers")
now = datetime.datetime.now().replace(tzinfo=dateutil.tz.tzlocal()) now = datetime.datetime.now().replace(tzinfo=dateutil.tz.tzlocal())
print(f'Now = {now}') print(f"Now = {now}")
datestring_now = format_datetime(now) datestring_now = format_datetime(now)
print(f'format_datetime : {datestring_now}') print(f"format_datetime : {datestring_now}")
print(f'format_datetime (from_timezone=utc) : {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}') print(
print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}') "format_datetime (from_timezone=utc) :"
print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone="local")}') f" {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}"
print(f'format_datetime (from_timezone=Paris) : {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}') )
print(f'format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}') print(
print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}') "format_datetime (from_timezone=local) :"
f" {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}"
)
print(
"format_datetime (from_timezone=local) :"
f' {format_datetime(now.replace(tzinfo=None), from_timezone="local")}'
)
print(
"format_datetime (from_timezone=Paris) :"
f' {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}'
)
print(f"format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}")
print(
"format_datetime (to_timezone=local) :"
f" {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone="local")}') print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone="local")}')
print(f'format_datetime (to_timezone=Tokyo) : {format_datetime(now, to_timezone="Asia/Tokyo")}') print(f'format_datetime (to_timezone=Tokyo) : {format_datetime(now, to_timezone="Asia/Tokyo")}')
print(f'format_datetime (naive=True) : {format_datetime(now, naive=True)}') print(f"format_datetime (naive=True) : {format_datetime(now, naive=True)}")
print(f'format_date : {format_date(now)}') print(f"format_date : {format_date(now)}")
print(f'format_date (from_timezone=utc) : {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}') print(
print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}') "format_date (from_timezone=utc) :"
print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone="local")}') f" {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}"
print(f'format_date (from_timezone=Paris) : {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}') )
print(f'format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}') print(
print(f'format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}') "format_date (from_timezone=local) :"
f" {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}"
)
print(
"format_date (from_timezone=local) :"
f' {format_date(now.replace(tzinfo=None), from_timezone="local")}'
)
print(
"format_date (from_timezone=Paris) :"
f' {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}'
)
print(f"format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}")
print(
f"format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'format_date (to_timezone=local) : {format_date(now, to_timezone="local")}') print(f'format_date (to_timezone=local) : {format_date(now, to_timezone="local")}')
print(f'format_date (to_timezone=Tokyo) : {format_date(now, to_timezone="Asia/Tokyo")}') print(f'format_date (to_timezone=Tokyo) : {format_date(now, to_timezone="Asia/Tokyo")}')
print(f'format_date (naive=True) : {format_date(now, naive=True)}') print(f"format_date (naive=True) : {format_date(now, naive=True)}")
print(f'parse_datetime : {parse_datetime(datestring_now)}') print(f"parse_datetime : {parse_datetime(datestring_now)}")
print(f'parse_datetime (default_timezone=utc) : {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}') print(
print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}') "parse_datetime (default_timezone=utc) :"
print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone="local")}') f" {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}"
print(f'parse_datetime (default_timezone=Paris) : {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}') )
print(f'parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}') print(
print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}') "parse_datetime (default_timezone=local) :"
print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone="local")}') f" {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}"
print(f'parse_datetime (to_timezone=Tokyo) : {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}') )
print(f'parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}') print(
"parse_datetime (default_timezone=local) :"
f' {parse_datetime(datestring_now[0:-1], default_timezone="local")}'
)
print(
"parse_datetime (default_timezone=Paris) :"
f' {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}'
)
print(
f"parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}"
)
print(
"parse_datetime (to_timezone=local) :"
f" {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}"
)
print(
"parse_datetime (to_timezone=local) :"
f' {parse_datetime(datestring_now, to_timezone="local")}'
)
print(
"parse_datetime (to_timezone=Tokyo) :"
f' {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}'
)
print(f"parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}")
print(f'parse_date : {parse_date(datestring_now)}') print(f"parse_date : {parse_date(datestring_now)}")
print(f'parse_date (default_timezone=utc) : {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}') print(
print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}') "parse_date (default_timezone=utc) :"
print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone="local")}') f" {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}"
print(f'parse_date (default_timezone=Paris) : {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}') )
print(f'parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}') print(
print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}') "parse_date (default_timezone=local) :"
f" {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}"
)
print(
"parse_date (default_timezone=local) :"
f' {parse_date(datestring_now[0:-1], default_timezone="local")}'
)
print(
"parse_date (default_timezone=Paris) :"
f' {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}'
)
print(f"parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}")
print(
"parse_date (to_timezone=local) :"
f" {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}"
)
print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone="local")}') print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone="local")}')
print(f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}') print(
print(f'parse_date (naive=True) : {parse_date(datestring_now, naive=True)}') f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}'
)
print(f"parse_date (naive=True) : {parse_date(datestring_now, naive=True)}")

View file

@ -64,6 +64,6 @@ def main(argv=None):
"mail": {"order": 12, "key": "email", "convert": lambda x: x.lower().strip()}, "mail": {"order": 12, "key": "email", "convert": lambda x: x.lower().strip()},
} }
print('Mapping source:\n' + pretty_format_value(src)) print("Mapping source:\n" + pretty_format_value(src))
print('Mapping config:\n' + pretty_format_value(map_c)) print("Mapping config:\n" + pretty_format_value(map_c))
print('Mapping result:\n' + pretty_format_value(map_hash(map_c, src))) print("Mapping result:\n" + pretty_format_value(map_hash(map_c, src)))

View file

@ -1,20 +1,16 @@
# -*- coding: utf-8 -*-
""" Test Progress bar """ """ Test Progress bar """
import logging import logging
import time
import sys import sys
import time
from mylib.pbar import Pbar from mylib.pbar import Pbar
from mylib.scripts.helpers import get_opts_parser from mylib.scripts.helpers import get_opts_parser, init_logging
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.pbar_test")
log = logging.getLogger('mylib.scripts.pbar_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
""" Script main """ """Script main"""
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
@ -23,20 +19,21 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(progress=True) parser = get_opts_parser(progress=True)
parser.add_argument( parser.add_argument(
'-c', '--count', "-c",
"--count",
action="store", action="store",
type=int, type=int,
dest="count", dest="count",
help=f'Progress bar max value (default: {default_max_val})', help=f"Progress bar max value (default: {default_max_val})",
default=default_max_val default=default_max_val,
) )
options = parser.parse_args() options = parser.parse_args()
# Initialize logs # Initialize logs
init_logging(options, 'Test Pbar') init_logging(options, "Test Pbar")
pbar = Pbar('Test', options.count, enabled=options.progress) pbar = Pbar("Test", options.count, enabled=options.progress)
for idx in range(0, options.count): # pylint: disable=unused-variable for idx in range(0, options.count): # pylint: disable=unused-variable
pbar.increment() pbar.increment()

View file

@ -1,19 +1,15 @@
# -*- coding: utf-8 -*-
""" Test report """ """ Test report """
import logging import logging
import sys import sys
from mylib.report import Report from mylib.report import Report
from mylib.scripts.helpers import get_opts_parser, add_email_opts from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging
from mylib.scripts.helpers import init_logging, init_email_client
log = logging.getLogger("mylib.scripts.report_test")
log = logging.getLogger('mylib.scripts.report_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
""" Script main """ """Script main"""
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
@ -21,14 +17,10 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True) parser = get_opts_parser(just_try=True)
add_email_opts(parser) add_email_opts(parser)
report_opts = parser.add_argument_group('Report options') report_opts = parser.add_argument_group("Report options")
report_opts.add_argument( report_opts.add_argument(
'-t', '--to', "-t", "--to", action="store", type=str, dest="report_rcpt", help="Send report to this email"
action="store",
type=str,
dest="report_rcpt",
help="Send report to this email"
) )
options = parser.parse_args() options = parser.parse_args()
@ -37,13 +29,13 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser.error("You must specify a report recipient using -t/--to parameter") parser.error("You must specify a report recipient using -t/--to parameter")
# Initialize logs # Initialize logs
report = Report(rcpt_to=options.report_rcpt, subject='Test report') report = Report(rcpt_to=options.report_rcpt, subject="Test report")
init_logging(options, 'Test Report', report=report) init_logging(options, "Test Report", report=report)
email_client = init_email_client(options) email_client = init_email_client(options)
report.send_at_exit(email_client=email_client) report.send_at_exit(email_client=email_client)
logging.debug('Test debug message') logging.debug("Test debug message")
logging.info('Test info message') logging.info("Test info message")
logging.warning('Test warning message') logging.warning("Test warning message")
logging.error('Test error message') logging.error("Test error message")

View file

@ -1,26 +1,21 @@
# -*- coding: utf-8 -*-
""" Test SFTP client """ """ Test SFTP client """
import atexit import atexit
import tempfile import getpass
import logging import logging
import sys
import os import os
import random import random
import string import string
import sys
import tempfile
import getpass from mylib.scripts.helpers import add_sftp_opts, get_opts_parser, init_logging
from mylib.sftp import SFTPClient from mylib.sftp import SFTPClient
from mylib.scripts.helpers import get_opts_parser, add_sftp_opts
from mylib.scripts.helpers import init_logging
log = logging.getLogger("mylib.scripts.sftp_test")
log = logging.getLogger('mylib.scripts.sftp_test')
def main(argv=None): # pylint: disable=too-many-locals,too-many-statements def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
""" Script main """ """Script main"""
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
@ -28,10 +23,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
parser = get_opts_parser(just_try=True) parser = get_opts_parser(just_try=True)
add_sftp_opts(parser) add_sftp_opts(parser)
test_opts = parser.add_argument_group('Test SFTP options') test_opts = parser.add_argument_group("Test SFTP options")
test_opts.add_argument( test_opts.add_argument(
'-p', '--remote-upload-path', "-p",
"--remote-upload-path",
action="store", action="store",
type=str, type=str,
dest="upload_path", dest="upload_path",
@ -41,66 +37,68 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements
options = parser.parse_args() options = parser.parse_args()
# Initialize logs # Initialize logs
init_logging(options, 'Test SFTP client') init_logging(options, "Test SFTP client")
if options.sftp_user and not options.sftp_password: if options.sftp_user and not options.sftp_password:
options.sftp_password = getpass.getpass('Please enter SFTP password: ') options.sftp_password = getpass.getpass("Please enter SFTP password: ")
log.info('Initialize Email client') log.info("Initialize Email client")
sftp = SFTPClient(options=options, just_try=options.just_try) sftp = SFTPClient(options=options, just_try=options.just_try)
sftp.connect() sftp.connect()
atexit.register(sftp.close) atexit.register(sftp.close)
log.debug('Create tempory file') log.debug("Create tempory file")
test_content = b'Juste un test.' test_content = b"Juste un test."
tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
tmp_file = os.path.join( tmp_file = os.path.join(
tmp_dir.name, tmp_dir.name, f'tmp{"".join(random.choice(string.ascii_lowercase) for i in range(8))}'
f'tmp{"".join(random.choice(string.ascii_lowercase) for i in range(8))}'
) )
log.debug('Temporary file path: "%s"', tmp_file) log.debug('Temporary file path: "%s"', tmp_file)
with open(tmp_file, 'wb') as file_desc: with open(tmp_file, "wb") as file_desc:
file_desc.write(test_content) file_desc.write(test_content)
log.debug( log.debug(
'Upload file %s to SFTP server (in %s)', tmp_file, "Upload file %s to SFTP server (in %s)",
options.upload_path if options.upload_path else "remote initial connection directory") tmp_file,
options.upload_path if options.upload_path else "remote initial connection directory",
)
if not sftp.upload_file(tmp_file, options.upload_path): if not sftp.upload_file(tmp_file, options.upload_path):
log.error('Fail to upload test file on SFTP server') log.error("Fail to upload test file on SFTP server")
sys.exit(1) sys.exit(1)
log.info('Test file uploaded on SFTP server') log.info("Test file uploaded on SFTP server")
remote_filepath = ( remote_filepath = (
os.path.join(options.upload_path, os.path.basename(tmp_file)) os.path.join(options.upload_path, os.path.basename(tmp_file))
if options.upload_path else os.path.basename(tmp_file) if options.upload_path
else os.path.basename(tmp_file)
) )
with tempfile.NamedTemporaryFile() as tmp_file2: with tempfile.NamedTemporaryFile() as tmp_file2:
log.info('Retrieve test file to %s', tmp_file2.name) log.info("Retrieve test file to %s", tmp_file2.name)
if not sftp.get_file(remote_filepath, tmp_file2.name): if not sftp.get_file(remote_filepath, tmp_file2.name):
log.error('Fail to retrieve test file') log.error("Fail to retrieve test file")
else: else:
with open(tmp_file2.name, 'rb') as file_desc: with open(tmp_file2.name, "rb") as file_desc:
content = file_desc.read() content = file_desc.read()
log.debug('Read content: %s', content) log.debug("Read content: %s", content)
if test_content == content: if test_content == content:
log.info('Content file retrieved match with uploaded one') log.info("Content file retrieved match with uploaded one")
else: else:
log.error('Content file retrieved doest not match with uploaded one') log.error("Content file retrieved doest not match with uploaded one")
try: try:
log.info('Remotly open test file %s', remote_filepath) log.info("Remotly open test file %s", remote_filepath)
file_desc = sftp.open_file(remote_filepath) file_desc = sftp.open_file(remote_filepath)
content = file_desc.read() content = file_desc.read()
log.debug('Read content: %s', content) log.debug("Read content: %s", content)
if test_content == content: if test_content == content:
log.info('Content of remote file match with uploaded one') log.info("Content of remote file match with uploaded one")
else: else:
log.error('Content of remote file doest not match with uploaded one') log.error("Content of remote file doest not match with uploaded one")
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
log.exception('An exception occurred remotly opening test file %s', remote_filepath) log.exception("An exception occurred remotly opening test file %s", remote_filepath)
if sftp.remove_file(remote_filepath): if sftp.remove_file(remote_filepath):
log.info('Test file removed on SFTP server') log.info("Test file removed on SFTP server")
else: else:
log.error('Fail to remove test file on SFTP server') log.error("Fail to remove test file on SFTP server")

View file

@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
""" SFTP client """ """ SFTP client """
import logging import logging
import os import os
from paramiko import SSHClient, AutoAddPolicy, SFTPAttributes from paramiko import AutoAddPolicy, SFTPAttributes, SSHClient
from mylib.config import ConfigurableObject from mylib.config import (
from mylib.config import BooleanOption BooleanOption,
from mylib.config import IntegerOption ConfigurableObject,
from mylib.config import PasswordOption IntegerOption,
from mylib.config import StringOption PasswordOption,
StringOption,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -23,16 +23,16 @@ class SFTPClient(ConfigurableObject):
This class abstract all interactions with the SFTP server. This class abstract all interactions with the SFTP server.
""" """
_config_name = 'sftp' _config_name = "sftp"
_config_comment = 'SFTP' _config_comment = "SFTP"
_defaults = { _defaults = {
'host': 'localhost', "host": "localhost",
'port': 22, "port": 22,
'user': None, "user": None,
'password': None, "password": None,
'known_hosts': os.path.expanduser('~/.ssh/known_hosts'), "known_hosts": os.path.expanduser("~/.ssh/known_hosts"),
'auto_add_unknown_host_key': False, "auto_add_unknown_host_key": False,
'just_try': False, "just_try": False,
} }
ssh_client = None ssh_client = None
@ -41,58 +41,77 @@ class SFTPClient(ConfigurableObject):
# pylint: disable=arguments-differ,arguments-renamed # pylint: disable=arguments-differ,arguments-renamed
def configure(self, just_try=True, **kwargs): def configure(self, just_try=True, **kwargs):
""" Configure options on registered mylib.Config object """ """Configure options on registered mylib.Config object"""
section = super().configure(**kwargs) section = super().configure(**kwargs)
section.add_option( section.add_option(
StringOption, 'host', default=self._defaults['host'], StringOption,
comment='SFTP server hostname/IP address') "host",
default=self._defaults["host"],
comment="SFTP server hostname/IP address",
)
section.add_option( section.add_option(
IntegerOption, 'port', default=self._defaults['port'], IntegerOption, "port", default=self._defaults["port"], comment="SFTP server port"
comment='SFTP server port') )
section.add_option( section.add_option(
StringOption, 'user', default=self._defaults['user'], StringOption,
comment='SFTP authentication username') "user",
default=self._defaults["user"],
comment="SFTP authentication username",
)
section.add_option( section.add_option(
PasswordOption, 'password', default=self._defaults['password'], PasswordOption,
"password",
default=self._defaults["password"],
comment='SFTP authentication password (set to "keyring" to use XDG keyring)', comment='SFTP authentication password (set to "keyring" to use XDG keyring)',
username_option='user', keyring_value='keyring') username_option="user",
keyring_value="keyring",
)
section.add_option( section.add_option(
StringOption, 'known_hosts', default=self._defaults['known_hosts'], StringOption,
comment='SFTP known_hosts filepath') "known_hosts",
default=self._defaults["known_hosts"],
comment="SFTP known_hosts filepath",
)
section.add_option( section.add_option(
BooleanOption, 'auto_add_unknown_host_key', BooleanOption,
default=self._defaults['auto_add_unknown_host_key'], "auto_add_unknown_host_key",
comment='Auto add unknown host key') default=self._defaults["auto_add_unknown_host_key"],
comment="Auto add unknown host key",
)
if just_try: if just_try:
section.add_option( section.add_option(
BooleanOption, 'just_try', default=self._defaults['just_try'], BooleanOption,
comment='Just-try mode: do not really make change on remote SFTP host') "just_try",
default=self._defaults["just_try"],
comment="Just-try mode: do not really make change on remote SFTP host",
)
return section return section
def initialize(self, loaded_config=None): def initialize(self, loaded_config=None):
""" Configuration initialized hook """ """Configuration initialized hook"""
super().__init__(loaded_config=loaded_config) super().__init__(loaded_config=loaded_config)
def connect(self): def connect(self):
""" Connect to SFTP server """ """Connect to SFTP server"""
if self.ssh_client: if self.ssh_client:
return return
host = self._get_option('host') host = self._get_option("host")
port = self._get_option('port') port = self._get_option("port")
log.info("Connect to SFTP server %s:%d", host, port) log.info("Connect to SFTP server %s:%d", host, port)
self.ssh_client = SSHClient() self.ssh_client = SSHClient()
if self._get_option('known_hosts'): if self._get_option("known_hosts"):
self.ssh_client.load_host_keys(self._get_option('known_hosts')) self.ssh_client.load_host_keys(self._get_option("known_hosts"))
if self._get_option('auto_add_unknown_host_key'): if self._get_option("auto_add_unknown_host_key"):
log.debug('Set missing host key policy to auto-add') log.debug("Set missing host key policy to auto-add")
self.ssh_client.set_missing_host_key_policy(AutoAddPolicy()) self.ssh_client.set_missing_host_key_policy(AutoAddPolicy())
self.ssh_client.connect( self.ssh_client.connect(
host, port=port, host,
username=self._get_option('user'), port=port,
password=self._get_option('password') username=self._get_option("user"),
password=self._get_option("password"),
) )
self.sftp_client = self.ssh_client.open_sftp() self.sftp_client = self.ssh_client.open_sftp()
self.initial_directory = self.sftp_client.getcwd() self.initial_directory = self.sftp_client.getcwd()
@ -103,43 +122,43 @@ class SFTPClient(ConfigurableObject):
self.initial_directory = "" self.initial_directory = ""
def get_file(self, remote_filepath, local_filepath): def get_file(self, remote_filepath, local_filepath):
""" Retrieve a file from SFTP server """ """Retrieve a file from SFTP server"""
self.connect() self.connect()
log.debug("Retreive file '%s' to '%s'", remote_filepath, local_filepath) log.debug("Retreive file '%s' to '%s'", remote_filepath, local_filepath)
return self.sftp_client.get(remote_filepath, local_filepath) is None return self.sftp_client.get(remote_filepath, local_filepath) is None
def open_file(self, remote_filepath, mode='r'): def open_file(self, remote_filepath, mode="r"):
""" Remotly open a file on SFTP server """ """Remotly open a file on SFTP server"""
self.connect() self.connect()
log.debug("Remotly open file '%s'", remote_filepath) log.debug("Remotly open file '%s'", remote_filepath)
return self.sftp_client.open(remote_filepath, mode=mode) return self.sftp_client.open(remote_filepath, mode=mode)
def upload_file(self, filepath, remote_directory=None): def upload_file(self, filepath, remote_directory=None):
""" Upload a file on SFTP server """ """Upload a file on SFTP server"""
self.connect() self.connect()
remote_filepath = os.path.join( remote_filepath = os.path.join(
remote_directory if remote_directory else self.initial_directory, remote_directory if remote_directory else self.initial_directory,
os.path.basename(filepath) os.path.basename(filepath),
) )
log.debug("Upload file '%s' to '%s'", filepath, remote_filepath) log.debug("Upload file '%s' to '%s'", filepath, remote_filepath)
if self._get_option('just_try'): if self._get_option("just_try"):
log.debug( log.debug(
"Just-try mode: do not really upload file '%s' to '%s'", "Just-try mode: do not really upload file '%s' to '%s'", filepath, remote_filepath
filepath, remote_filepath) )
return True return True
result = self.sftp_client.put(filepath, remote_filepath) result = self.sftp_client.put(filepath, remote_filepath)
return isinstance(result, SFTPAttributes) return isinstance(result, SFTPAttributes)
def remove_file(self, filepath): def remove_file(self, filepath):
""" Remove a file on SFTP server """ """Remove a file on SFTP server"""
self.connect() self.connect()
log.debug("Remove file '%s'", filepath) log.debug("Remove file '%s'", filepath)
if self._get_option('just_try'): if self._get_option("just_try"):
log.debug("Just - try mode: do not really remove file '%s'", filepath) log.debug("Just - try mode: do not really remove file '%s'", filepath)
return True return True
return self.sftp_client.remove(filepath) is None return self.sftp_client.remove(filepath) is None
def close(self): def close(self):
""" Close SSH/SFTP connection """ """Close SSH/SFTP connection"""
log.debug("Close connection") log.debug("Close connection")
self.ssh_client.close() self.ssh_client.close()

View file

@ -8,45 +8,43 @@ log = logging.getLogger(__name__)
class TelltaleFile: class TelltaleFile:
""" Telltale file helper class """ """Telltale file helper class"""
def __init__(self, filepath=None, filename=None, dirpath=None): def __init__(self, filepath=None, filename=None, dirpath=None):
assert filepath or filename, "filename or filepath is required" assert filepath or filename, "filename or filepath is required"
if filepath: if filepath:
assert not filename or os.path.basename(filepath) == filename, "filepath and filename does not match" assert (
assert not dirpath or os.path.dirname(filepath) == dirpath, "filepath and dirpath does not match" not filename or os.path.basename(filepath) == filename
), "filepath and filename does not match"
assert (
not dirpath or os.path.dirname(filepath) == dirpath
), "filepath and dirpath does not match"
self.filename = filename if filename else os.path.basename(filepath) self.filename = filename if filename else os.path.basename(filepath)
self.dirpath = ( self.dirpath = (
dirpath if dirpath dirpath if dirpath else (os.path.dirname(filepath) if filepath else os.getcwd())
else (
os.path.dirname(filepath) if filepath
else os.getcwd()
)
) )
self.filepath = filepath if filepath else os.path.join(self.dirpath, self.filename) self.filepath = filepath if filepath else os.path.join(self.dirpath, self.filename)
@property @property
def last_update(self): def last_update(self):
""" Retreive last update datetime of the telltall file """ """Retreive last update datetime of the telltall file"""
try: try:
return datetime.datetime.fromtimestamp( return datetime.datetime.fromtimestamp(os.stat(self.filepath).st_mtime)
os.stat(self.filepath).st_mtime
)
except FileNotFoundError: except FileNotFoundError:
log.info('Telltale file not found (%s)', self.filepath) log.info("Telltale file not found (%s)", self.filepath)
return None return None
def update(self): def update(self):
""" Update the telltale file """ """Update the telltale file"""
log.info('Update telltale file (%s)', self.filepath) log.info("Update telltale file (%s)", self.filepath)
try: try:
os.utime(self.filepath, None) os.utime(self.filepath, None)
except FileNotFoundError: except FileNotFoundError:
# pylint: disable=consider-using-with # pylint: disable=consider-using-with
open(self.filepath, 'a', encoding="utf-8").close() open(self.filepath, "a", encoding="utf-8").close()
def remove(self): def remove(self):
""" Remove the telltale file """ """Remove the telltale file"""
try: try:
os.remove(self.filepath) os.remove(self.filepath)
return True return True

View file

@ -1,2 +1,3 @@
[flake8] [flake8]
ignore = E501,W503 ignore = E501,W503
max-line-length = 100

View file

@ -1,76 +1,74 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- """Setuptools script"""
from setuptools import find_packages
from setuptools import setup
from setuptools import find_packages, setup
extras_require = { extras_require = {
'dev': [ "dev": [
'pytest', "pytest",
'mocker', "mocker",
'pytest-mock', "pytest-mock",
'pylint', "pylint == 2.15.10",
'flake8', "pre-commit",
], ],
'config': [ "config": [
'argcomplete', "argcomplete",
'keyring', "keyring",
'systemd-python', "systemd-python",
], ],
'ldap': [ "ldap": [
'python-ldap', "python-ldap",
'python-dateutil', "python-dateutil",
'pytz', "pytz",
], ],
'email': [ "email": [
'mako', "mako",
], ],
'pgsql': [ "pgsql": [
'psycopg2', "psycopg2",
], ],
'oracle': [ "oracle": [
'cx_Oracle', "cx_Oracle",
], ],
'mysql': [ "mysql": [
'mysqlclient', "mysqlclient",
], ],
'sftp': [ "sftp": [
'paramiko', "paramiko",
], ],
} }
install_requires = ['progressbar'] install_requires = ["progressbar"]
for extra, deps in extras_require.items(): for extra, deps in extras_require.items():
if extra != 'dev': if extra != "dev":
install_requires.extend(deps) install_requires.extend(deps)
version = '0.1' version = "0.1"
setup( setup(
name="mylib", name="mylib",
version=version, version=version,
description='A set of helpers small libs to make common tasks easier in my script development', description="A set of helpers small libs to make common tasks easier in my script development",
classifiers=[ classifiers=[
'Programming Language :: Python', "Programming Language :: Python",
], ],
install_requires=install_requires, install_requires=install_requires,
extras_require=extras_require, extras_require=extras_require,
author='Benjamin Renard', author="Benjamin Renard",
author_email='brenard@zionetrix.net', author_email="brenard@zionetrix.net",
url='https://gogs.zionetrix.net/bn8/python-mylib', url="https://gogs.zionetrix.net/bn8/python-mylib",
packages=find_packages(), packages=find_packages(),
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
entry_points={ entry_points={
'console_scripts': [ "console_scripts": [
'mylib-test-email = mylib.scripts.email_test:main', "mylib-test-email = mylib.scripts.email_test:main",
'mylib-test-email-with-config = mylib.scripts.email_test_with_config:main', "mylib-test-email-with-config = mylib.scripts.email_test_with_config:main",
'mylib-test-map = mylib.scripts.map_test:main', "mylib-test-map = mylib.scripts.map_test:main",
'mylib-test-pbar = mylib.scripts.pbar_test:main', "mylib-test-pbar = mylib.scripts.pbar_test:main",
'mylib-test-report = mylib.scripts.report_test:main', "mylib-test-report = mylib.scripts.report_test:main",
'mylib-test-ldap = mylib.scripts.ldap_test:main', "mylib-test-ldap = mylib.scripts.ldap_test:main",
'mylib-test-sftp = mylib.scripts.sftp_test:main', "mylib-test-sftp = mylib.scripts.sftp_test:main",
], ],
}, },
) )

View file

@ -22,18 +22,11 @@ echo "Install package with dev dependencies using pip..."
$VENV/bin/python3 -m pip install -e ".[dev]" $QUIET_ARG $VENV/bin/python3 -m pip install -e ".[dev]" $QUIET_ARG
RES=0 RES=0
# Run tests
$VENV/bin/python3 -m pytest tests
[ $? -ne 0 ] && RES=1
# Run pylint # Run pre-commit
echo "Run pylint..." echo "Run pre-commit..."
$VENV/bin/pylint --extension-pkg-whitelist=cx_Oracle mylib tests source $VENV/bin/activate
[ $? -ne 0 ] && RES=1 pre-commit run --all-files
# Run flake8
echo "Run flake8..."
$VENV/bin/flake8 mylib tests
[ $? -ne 0 ] && RES=1 [ $? -ne 0 ] && RES=1
# Clean temporary venv # Clean temporary venv

View file

@ -2,31 +2,29 @@
# pylint: disable=global-variable-not-assigned # pylint: disable=global-variable-not-assigned
""" Tests on config lib """ """ Tests on config lib """
import configparser
import logging import logging
import os import os
import configparser
import pytest import pytest
from mylib.config import Config, ConfigSection from mylib.config import BooleanOption, Config, ConfigSection, StringOption
from mylib.config import BooleanOption
from mylib.config import StringOption
runned = {} runned = {}
def test_config_init_default_args(): def test_config_init_default_args():
appname = 'Test app' appname = "Test app"
config = Config(appname) config = Config(appname)
assert config.appname == appname assert config.appname == appname
assert config.version == '0.0' assert config.version == "0.0"
assert config.encoding == 'utf-8' assert config.encoding == "utf-8"
def test_config_init_custom_args(): def test_config_init_custom_args():
appname = 'Test app' appname = "Test app"
version = '1.43' version = "1.43"
encoding = 'ISO-8859-1' encoding = "ISO-8859-1"
config = Config(appname, version=version, encoding=encoding) config = Config(appname, version=version, encoding=encoding)
assert config.appname == appname assert config.appname == appname
assert config.version == version assert config.version == version
@ -34,8 +32,8 @@ def test_config_init_custom_args():
def test_add_section_default_args(): def test_add_section_default_args():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
section = config.add_section(name) section = config.add_section(name)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
assert config.sections[name] == section assert config.sections[name] == section
@ -45,9 +43,9 @@ def test_add_section_default_args():
def test_add_section_custom_args(): def test_add_section_custom_args():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
comment = 'Test' comment = "Test"
order = 20 order = 20
section = config.add_section(name, comment=comment, order=order) section = config.add_section(name, comment=comment, order=order)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
@ -57,47 +55,47 @@ def test_add_section_custom_args():
def test_add_section_with_callback(): def test_add_section_with_callback():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
global runned global runned
runned['test_add_section_with_callback'] = False runned["test_add_section_with_callback"] = False
def test_callback(loaded_config): def test_callback(loaded_config):
global runned global runned
assert loaded_config == config assert loaded_config == config
assert runned['test_add_section_with_callback'] is False assert runned["test_add_section_with_callback"] is False
runned['test_add_section_with_callback'] = True runned["test_add_section_with_callback"] = True
section = config.add_section(name, loaded_callback=test_callback) section = config.add_section(name, loaded_callback=test_callback)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
assert test_callback in config._loaded_callbacks assert test_callback in config._loaded_callbacks
assert runned['test_add_section_with_callback'] is False assert runned["test_add_section_with_callback"] is False
config.parse_arguments_options(argv=[], create=False) config.parse_arguments_options(argv=[], create=False)
assert runned['test_add_section_with_callback'] is True assert runned["test_add_section_with_callback"] is True
assert test_callback in config._loaded_callbacks_executed assert test_callback in config._loaded_callbacks_executed
# Try to execute again to verify callback is not runned again # Try to execute again to verify callback is not runned again
config._loaded() config._loaded()
def test_add_section_with_callback_already_loaded(): def test_add_section_with_callback_already_loaded():
config = Config('Test app') config = Config("Test app")
name = 'test_section' name = "test_section"
config.parse_arguments_options(argv=[], create=False) config.parse_arguments_options(argv=[], create=False)
global runned global runned
runned['test_add_section_with_callback_already_loaded'] = False runned["test_add_section_with_callback_already_loaded"] = False
def test_callback(loaded_config): def test_callback(loaded_config):
global runned global runned
assert loaded_config == config assert loaded_config == config
assert runned['test_add_section_with_callback_already_loaded'] is False assert runned["test_add_section_with_callback_already_loaded"] is False
runned['test_add_section_with_callback_already_loaded'] = True runned["test_add_section_with_callback_already_loaded"] = True
section = config.add_section(name, loaded_callback=test_callback) section = config.add_section(name, loaded_callback=test_callback)
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
assert runned['test_add_section_with_callback_already_loaded'] is True assert runned["test_add_section_with_callback_already_loaded"] is True
assert test_callback in config._loaded_callbacks assert test_callback in config._loaded_callbacks
assert test_callback in config._loaded_callbacks_executed assert test_callback in config._loaded_callbacks_executed
# Try to execute again to verify callback is not runned again # Try to execute again to verify callback is not runned again
@ -105,10 +103,10 @@ def test_add_section_with_callback_already_loaded():
def test_add_option_default_args(): def test_add_option_default_args():
config = Config('Test app') config = Config("Test app")
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
name = 'my_option' name = "my_option"
option = section.add_option(StringOption, name) option = section.add_option(StringOption, name)
assert isinstance(option, StringOption) assert isinstance(option, StringOption)
assert name in section.options and section.options[name] == option assert name in section.options and section.options[name] == option
@ -124,17 +122,17 @@ def test_add_option_default_args():
def test_add_option_custom_args(): def test_add_option_custom_args():
config = Config('Test app') config = Config("Test app")
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
name = 'my_option' name = "my_option"
kwargs = dict( kwargs = dict(
default='default value', default="default value",
comment='my comment', comment="my comment",
no_arg=True, no_arg=True,
arg='--my-option', arg="--my-option",
short_arg='-M', short_arg="-M",
arg_help='My help' arg_help="My help",
) )
option = section.add_option(StringOption, name, **kwargs) option = section.add_option(StringOption, name, **kwargs)
assert isinstance(option, StringOption) assert isinstance(option, StringOption)
@ -148,12 +146,12 @@ def test_add_option_custom_args():
def test_defined(): def test_defined():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
assert not config.defined(section_name, opt_name) assert not config.defined(section_name, opt_name)
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
section.add_option(StringOption, opt_name) section.add_option(StringOption, opt_name)
@ -161,29 +159,29 @@ def test_defined():
def test_isset(): def test_isset():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
assert not config.isset(section_name, opt_name) assert not config.isset(section_name, opt_name)
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
option = section.add_option(StringOption, opt_name) option = section.add_option(StringOption, opt_name)
assert not config.isset(section_name, opt_name) assert not config.isset(section_name, opt_name)
config.parse_arguments_options(argv=[option.parser_argument_name, 'value'], create=False) config.parse_arguments_options(argv=[option.parser_argument_name, "value"], create=False)
assert config.isset(section_name, opt_name) assert config.isset(section_name, opt_name)
def test_not_isset(): def test_not_isset():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
assert not config.isset(section_name, opt_name) assert not config.isset(section_name, opt_name)
section = config.add_section('my_section') section = config.add_section("my_section")
assert isinstance(section, ConfigSection) assert isinstance(section, ConfigSection)
section.add_option(StringOption, opt_name) section.add_option(StringOption, opt_name)
@ -195,11 +193,11 @@ def test_not_isset():
def test_get(): def test_get():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
opt_value = 'value' opt_value = "value"
section = config.add_section('my_section') section = config.add_section("my_section")
option = section.add_option(StringOption, opt_name) option = section.add_option(StringOption, opt_name)
config.parse_arguments_options(argv=[option.parser_argument_name, opt_value], create=False) config.parse_arguments_options(argv=[option.parser_argument_name, opt_value], create=False)
@ -207,11 +205,11 @@ def test_get():
def test_get_default(): def test_get_default():
config = Config('Test app') config = Config("Test app")
section_name = 'my_section' section_name = "my_section"
opt_name = 'my_option' opt_name = "my_option"
opt_default_value = 'value' opt_default_value = "value"
section = config.add_section('my_section') section = config.add_section("my_section")
section.add_option(StringOption, opt_name, default=opt_default_value) section.add_option(StringOption, opt_name, default=opt_default_value)
config.parse_arguments_options(argv=[], create=False) config.parse_arguments_options(argv=[], create=False)
@ -219,8 +217,8 @@ def test_get_default():
def test_logging_splited_stdout_stderr(capsys): def test_logging_splited_stdout_stderr(capsys):
config = Config('Test app') config = Config("Test app")
config.parse_arguments_options(argv=['-C', '-v'], create=False) config.parse_arguments_options(argv=["-C", "-v"], create=False)
info_msg = "[info]" info_msg = "[info]"
err_msg = "[error]" err_msg = "[error]"
logging.getLogger().info(info_msg) logging.getLogger().info(info_msg)
@ -239,9 +237,9 @@ def test_logging_splited_stdout_stderr(capsys):
@pytest.fixture() @pytest.fixture()
def config_with_file(tmpdir): def config_with_file(tmpdir):
config = Config('Test app') config = Config("Test app")
config_dir = tmpdir.mkdir('config') config_dir = tmpdir.mkdir("config")
config_file = config_dir.join('config.ini') config_file = config_dir.join("config.ini")
config.save(os.path.join(config_file.dirname, config_file.basename)) config.save(os.path.join(config_file.dirname, config_file.basename))
return config return config
@ -250,6 +248,7 @@ def generate_mock_input(expected_prompt, input_value):
def mock_input(self, prompt): # pylint: disable=unused-argument def mock_input(self, prompt): # pylint: disable=unused-argument
assert prompt == expected_prompt assert prompt == expected_prompt
return input_value return input_value
return mock_input return mock_input
@ -257,10 +256,9 @@ def generate_mock_input(expected_prompt, input_value):
def test_boolean_option_from_config(config_with_file): def test_boolean_option_from_config(config_with_file):
section = config_with_file.add_section('test') section = config_with_file.add_section("test")
default = True default = True
option = section.add_option( option = section.add_option(BooleanOption, "test_bool", default=default)
BooleanOption, 'test_bool', default=default)
config_with_file.save() config_with_file.save()
option.set(not default) option.set(not default)
@ -273,74 +271,76 @@ def test_boolean_option_from_config(config_with_file):
def test_boolean_option_ask_value(mocker): def test_boolean_option_ask_value(mocker):
config = Config('Test app') config = Config("Test app")
section = config.add_section('test') section = config.add_section("test")
name = 'test_bool' name = "test_bool"
option = section.add_option( option = section.add_option(BooleanOption, name, default=True)
BooleanOption, name, default=True)
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "y")
generate_mock_input(f'{name}: [Y/n] ', 'y')
) )
assert option.ask_value(set_it=False) is True assert option.ask_value(set_it=False) is True
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "Y")
generate_mock_input(f'{name}: [Y/n] ', 'Y')
) )
assert option.ask_value(set_it=False) is True assert option.ask_value(set_it=False) is True
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "")
generate_mock_input(f'{name}: [Y/n] ', '')
) )
assert option.ask_value(set_it=False) is True assert option.ask_value(set_it=False) is True
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "n")
generate_mock_input(f'{name}: [Y/n] ', 'n')
) )
assert option.ask_value(set_it=False) is False assert option.ask_value(set_it=False) is False
mocker.patch( mocker.patch(
'mylib.config.BooleanOption._get_user_input', "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "N")
generate_mock_input(f'{name}: [Y/n] ', 'N')
) )
assert option.ask_value(set_it=False) is False assert option.ask_value(set_it=False) is False
def test_boolean_option_to_config(): def test_boolean_option_to_config():
config = Config('Test app') config = Config("Test app")
section = config.add_section('test') section = config.add_section("test")
default = True default = True
option = section.add_option(BooleanOption, 'test_bool', default=default) option = section.add_option(BooleanOption, "test_bool", default=default)
assert option.to_config(True) == 'true' assert option.to_config(True) == "true"
assert option.to_config(False) == 'false' assert option.to_config(False) == "false"
def test_boolean_option_export_to_config(config_with_file): def test_boolean_option_export_to_config(config_with_file):
section = config_with_file.add_section('test') section = config_with_file.add_section("test")
name = 'test_bool' name = "test_bool"
comment = 'Test boolean' comment = "Test boolean"
default = True default = True
option = section.add_option( option = section.add_option(BooleanOption, name, default=default, comment=comment)
BooleanOption, name, default=default, comment=comment)
assert option.export_to_config() == f"""# {comment} assert (
option.export_to_config()
== f"""# {comment}
# Default: {str(default).lower()} # Default: {str(default).lower()}
# {name} = # {name} =
""" """
)
option.set(not default) option.set(not default)
assert option.export_to_config() == f"""# {comment} assert (
option.export_to_config()
== f"""# {comment}
# Default: {str(default).lower()} # Default: {str(default).lower()}
{name} = {str(not default).lower()} {name} = {str(not default).lower()}
""" """
)
option.set(default) option.set(default)
assert option.export_to_config() == f"""# {comment} assert (
option.export_to_config()
== f"""# {comment}
# Default: {str(default).lower()} # Default: {str(default).lower()}
# {name} = # {name} =
""" """
)

View file

@ -2,16 +2,17 @@
""" Tests on opening hours helpers """ """ Tests on opening hours helpers """
import pytest import pytest
from MySQLdb._exceptions import Error from MySQLdb._exceptions import Error
from mylib.mysql import MyDB from mylib.mysql import MyDB
class FakeMySQLdbCursor: class FakeMySQLdbCursor:
""" Fake MySQLdb cursor """ """Fake MySQLdb cursor"""
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
self.expected_sql = expected_sql self.expected_sql = expected_sql
self.expected_params = expected_params self.expected_params = expected_params
self.expected_return = expected_return self.expected_return = expected_return
@ -20,13 +21,25 @@ class FakeMySQLdbCursor:
def execute(self, sql, params=None): def execute(self, sql, params=None):
if self.expected_exception: if self.expected_exception:
raise Error(f'{self}.execute({sql}, {params}): expected exception') raise Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith('select '): if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
return self.expected_return return self.expected_return
@property @property
@ -39,21 +52,19 @@ class FakeMySQLdbCursor:
def fetchall(self): def fetchall(self):
if isinstance(self.expected_return, list): if isinstance(self.expected_return, list):
return ( return (
list(row.values()) list(row.values()) if isinstance(row, dict) else row for row in self.expected_return
if isinstance(row, dict) else row
for row in self.expected_return
) )
return self.expected_return return self.expected_return
def __repr__(self): def __repr__(self):
return ( return (
f'FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, ' f"FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, "
f'{self.expected_return}, {self.expected_just_try})' f"{self.expected_return}, {self.expected_just_try})"
) )
class FakeMySQLdb: class FakeMySQLdb:
""" Fake MySQLdb connection """ """Fake MySQLdb connection"""
expected_sql = None expected_sql = None
expected_params = None expected_params = None
@ -63,11 +74,14 @@ class FakeMySQLdb:
just_try = False just_try = False
def __init__(self, **kwargs): def __init__(self, **kwargs):
allowed_kwargs = dict(db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool) allowed_kwargs = dict(
db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool
)
for arg, value in kwargs.items(): for arg, value in kwargs.items():
assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"' assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"'
assert isinstance(value, allowed_kwargs[arg]), \ assert isinstance(
f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})' value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
setattr(self, arg, value) setattr(self, arg, value)
def close(self): def close(self):
@ -75,9 +89,11 @@ class FakeMySQLdb:
def cursor(self): def cursor(self):
return FakeMySQLdbCursor( return FakeMySQLdbCursor(
self.expected_sql, self.expected_params, self.expected_sql,
self.expected_return, self.expected_just_try or self.just_try, self.expected_params,
self.expected_exception self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
) )
def commit(self): def commit(self):
@ -105,19 +121,19 @@ def fake_mysqldb_connect_just_try(**kwargs):
@pytest.fixture @pytest.fixture
def test_mydb(): def test_mydb():
return MyDB('127.0.0.1', 'user', 'password', 'dbname') return MyDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_mydb(mocker): def fake_mydb(mocker):
mocker.patch('MySQLdb.connect', fake_mysqldb_connect) mocker.patch("MySQLdb.connect", fake_mysqldb_connect)
return MyDB('127.0.0.1', 'user', 'password', 'dbname') return MyDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_just_try_mydb(mocker): def fake_just_try_mydb(mocker):
mocker.patch('MySQLdb.connect', fake_mysqldb_connect_just_try) mocker.patch("MySQLdb.connect", fake_mysqldb_connect_just_try)
return MyDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True) return MyDB("127.0.0.1", "user", "password", "dbname", just_try=True)
@pytest.fixture @pytest.fixture
@ -132,13 +148,22 @@ def fake_connected_just_try_mydb(fake_just_try_mydb):
return fake_just_try_mydb return fake_just_try_mydb
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs): def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
return expected_return return expected_return
return mock_args return mock_args
@ -146,13 +171,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode" assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
return expected_return return expected_return
return mock_doSQL return mock_doSQL
@ -166,15 +200,11 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter(): def test_combine_params_with_to_add_parameter():
assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict( assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs(): def test_combine_params_with_kargs():
assert MyDB._combine_params(dict(test1=1), test2=2) == dict( assert MyDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter(): def test_combine_params_with_kargs_and_to_add_parameter():
@ -184,47 +214,40 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved(): def test_format_where_clauses_params_are_preserved():
args = ('test = test', dict(test1=1)) args = ("test = test", dict(test1=1))
assert MyDB._format_where_clauses(*args) == args assert MyDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw(): def test_format_where_clauses_raw():
assert MyDB._format_where_clauses('test = test') == (('test = test'), {}) assert MyDB._format_where_clauses("test = test") == ("test = test", {})
def test_format_where_clauses_tuple_clause_with_params(): def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ( where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2))
'test1 = %(test1)s AND test2 = %(test2)s',
dict(test1=1, test2=2)
)
assert MyDB._format_where_clauses(where_clauses) == where_clauses assert MyDB._format_where_clauses(where_clauses) == where_clauses
def test_format_where_clauses_dict(): def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert MyDB._format_where_clauses(where_clauses) == ( assert MyDB._format_where_clauses(where_clauses) == (
'`test1` = %(test1)s AND `test2` = %(test2)s', "`test1` = %(test1)s AND `test2` = %(test2)s",
where_clauses where_clauses,
) )
def test_format_where_clauses_combined_types(): def test_format_where_clauses_combined_types():
where_clauses = ( where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4))
'test1 = 1',
('test2 LIKE %(test2)s', dict(test2=2)),
dict(test3=3, test4=4)
)
assert MyDB._format_where_clauses(where_clauses) == ( assert MyDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s', "test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s",
dict(test2=2, test3=3, test4=4) dict(test2=2, test3=3, test4=4),
) )
def test_format_where_clauses_with_where_op(): def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert MyDB._format_where_clauses(where_clauses, where_op='OR') == ( assert MyDB._format_where_clauses(where_clauses, where_op="OR") == (
'`test1` = %(test1)s OR `test2` = %(test2)s', "`test1` = %(test1)s OR `test2` = %(test2)s",
where_clauses where_clauses,
) )
@ -232,8 +255,8 @@ def test_add_where_clauses():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert MyDB._add_where_clauses(sql, None, where_clauses) == ( assert MyDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s', sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s",
where_clauses where_clauses,
) )
@ -242,106 +265,102 @@ def test_add_where_clauses_preserved_params():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
params = dict(fake1=1) params = dict(fake1=1)
assert MyDB._add_where_clauses(sql, params.copy(), where_clauses) == ( assert MyDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s', sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s",
dict(**where_clauses, **params) dict(**where_clauses, **params),
) )
def test_add_where_clauses_with_op(): def test_add_where_clauses_with_op():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = ('test1=1', 'test2=2') where_clauses = ("test1=1", "test2=2")
assert MyDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( assert MyDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + ' WHERE test1=1 OR test2=2', sql + " WHERE test1=1 OR test2=2",
{} {},
) )
def test_add_where_clauses_with_duplicated_field(): def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=%(test1)s" sql = "UPDATE table SET test1=%(test1)s"
params = dict(test1='new_value') params = dict(test1="new_value")
where_clauses = dict(test1='where_value') where_clauses = dict(test1="where_value")
assert MyDB._add_where_clauses(sql, params, where_clauses) == ( assert MyDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE `test1` = %(test1_1)s', sql + " WHERE `test1` = %(test1_1)s",
dict(test1='new_value', test1_1='where_value') dict(test1="new_value", test1_1="where_value"),
) )
def test_quote_table_name(): def test_quote_table_name():
assert MyDB._quote_table_name("mytable") == '`mytable`' assert MyDB._quote_table_name("mytable") == "`mytable`"
assert MyDB._quote_table_name("myschema.mytable") == '`myschema`.`mytable`' assert MyDB._quote_table_name("myschema.mytable") == "`myschema`.`mytable`"
def test_insert(mocker, test_mydb): def test_insert(mocker, test_mydb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSQL', "mylib.mysql.MyDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)', "INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)", values
values ),
)
) )
assert test_mydb.insert('mytable', values) assert test_mydb.insert("mytable", values)
def test_insert_just_try(mocker, test_mydb): def test_insert_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.insert('mytable', dict(test1=1, test2=2), just_try=True) assert test_mydb.insert("mytable", dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_mydb): def test_update(mocker, test_mydb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSQL', "mylib.mysql.MyDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` = %(test3)s AND `test4` = %(test4)s', "UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` ="
dict(**values, **where_clauses) " %(test3)s AND `test4` = %(test4)s",
) dict(**values, **where_clauses),
),
) )
assert test_mydb.update('mytable', values, where_clauses) assert test_mydb.update("mytable", values, where_clauses)
def test_update_just_try(mocker, test_mydb): def test_update_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.update('mytable', dict(test1=1, test2=2), None, just_try=True) assert test_mydb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_mydb): def test_delete(mocker, test_mydb):
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSQL', "mylib.mysql.MyDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s', "DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s", where_clauses
where_clauses ),
)
) )
assert test_mydb.delete('mytable', where_clauses) assert test_mydb.delete("mytable", where_clauses)
def test_delete_just_try(mocker, test_mydb): def test_delete_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try)
assert test_mydb.delete('mytable', None, just_try=True) assert test_mydb.delete("mytable", None, just_try=True)
def test_truncate(mocker, test_mydb): def test_truncate(mocker, test_mydb):
mocker.patch( mocker.patch("mylib.mysql.MyDB.doSQL", generate_mock_doSQL("TRUNCATE TABLE `mytable`", None))
'mylib.mysql.MyDB.doSQL',
generate_mock_doSQL('TRUNCATE TABLE `mytable`', None)
)
assert test_mydb.truncate('mytable') assert test_mydb.truncate("mytable")
def test_truncate_just_try(mocker, test_mydb): def test_truncate_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try)
assert test_mydb.truncate('mytable', just_try=True) assert test_mydb.truncate("mytable", just_try=True)
def test_select(mocker, test_mydb): def test_select(mocker, test_mydb):
fields = ('field1', 'field2') fields = ("field1", "field2")
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
expected_return = [ expected_return = [
dict(field1=1, field2=2), dict(field1=1, field2=2),
@ -349,30 +368,28 @@ def test_select(mocker, test_mydb):
] ]
order_by = "field1, DESC" order_by = "field1, DESC"
mocker.patch( mocker.patch(
'mylib.mysql.MyDB.doSelect', "mylib.mysql.MyDB.doSelect",
generate_mock_doSQL( generate_mock_doSQL(
'SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` = %(test4)s ORDER BY ' + order_by, "SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` ="
where_clauses, expected_return " %(test4)s ORDER BY " + order_by,
) where_clauses,
expected_return,
),
) )
assert test_mydb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return assert test_mydb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
def test_select_without_field_and_order_by(mocker, test_mydb): def test_select_without_field_and_order_by(mocker, test_mydb):
mocker.patch( mocker.patch("mylib.mysql.MyDB.doSelect", generate_mock_doSQL("SELECT * FROM `mytable`"))
'mylib.mysql.MyDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM `mytable`'
)
)
assert test_mydb.select('mytable') assert test_mydb.select("mytable")
def test_select_just_try(mocker, test_mydb): def test_select_just_try(mocker, test_mydb):
mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try)
assert test_mydb.select('mytable', None, None, just_try=True) assert test_mydb.select("mytable", None, None, just_try=True)
# #
# Tests on main methods # Tests on main methods
@ -389,12 +406,7 @@ def test_connect(mocker, test_mydb):
use_unicode=True, use_unicode=True,
) )
mocker.patch( mocker.patch("MySQLdb.connect", generate_mock_args(expected_kwargs=expected_kwargs))
'MySQLdb.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_mydb.connect() assert test_mydb.connect()
@ -408,48 +420,61 @@ def test_close_connected(fake_connected_mydb):
def test_doSQL(fake_connected_mydb): def test_doSQL(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s' fake_connected_mydb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s"
fake_connected_mydb._conn.expected_params = dict(test1=1) fake_connected_mydb._conn.expected_params = dict(test1=1)
fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params) fake_connected_mydb.doSQL(
fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params
)
def test_doSQL_without_params(fake_connected_mydb): def test_doSQL_without_params(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'DELETE FROM table' fake_connected_mydb._conn.expected_sql = "DELETE FROM table"
fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql) fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_mydb): def test_doSQL_just_try(fake_connected_just_try_mydb):
assert fake_connected_just_try_mydb.doSQL('DELETE FROM table') assert fake_connected_just_try_mydb.doSQL("DELETE FROM table")
def test_doSQL_on_exception(fake_connected_mydb): def test_doSQL_on_exception(fake_connected_mydb):
fake_connected_mydb._conn.expected_exception = True fake_connected_mydb._conn.expected_exception = True
assert fake_connected_mydb.doSQL('DELETE FROM table') is False assert fake_connected_mydb.doSQL("DELETE FROM table") is False
def test_doSelect(fake_connected_mydb): def test_doSelect(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_mydb._conn.expected_params = dict(test1=1) fake_connected_mydb._conn.expected_params = dict(test1=1)
fake_connected_mydb._conn.expected_return = [dict(test1=1)] fake_connected_mydb._conn.expected_return = [dict(test1=1)]
assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params) == fake_connected_mydb._conn.expected_return assert (
fake_connected_mydb.doSelect(
fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params
)
== fake_connected_mydb._conn.expected_return
)
def test_doSelect_without_params(fake_connected_mydb): def test_doSelect_without_params(fake_connected_mydb):
fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table' fake_connected_mydb._conn.expected_sql = "SELECT * FROM table"
fake_connected_mydb._conn.expected_return = [dict(test1=1)] fake_connected_mydb._conn.expected_return = [dict(test1=1)]
assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql) == fake_connected_mydb._conn.expected_return assert (
fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql)
== fake_connected_mydb._conn.expected_return
)
def test_doSelect_on_exception(fake_connected_mydb): def test_doSelect_on_exception(fake_connected_mydb):
fake_connected_mydb._conn.expected_exception = True fake_connected_mydb._conn.expected_exception = True
assert fake_connected_mydb.doSelect('SELECT * FROM table') is False assert fake_connected_mydb.doSelect("SELECT * FROM table") is False
def test_doSelect_just_try(fake_connected_just_try_mydb): def test_doSelect_just_try(fake_connected_just_try_mydb):
fake_connected_just_try_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_just_try_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_just_try_mydb._conn.expected_params = dict(test1=1) fake_connected_just_try_mydb._conn.expected_params = dict(test1=1)
fake_connected_just_try_mydb._conn.expected_return = [dict(test1=1)] fake_connected_just_try_mydb._conn.expected_return = [dict(test1=1)]
assert fake_connected_just_try_mydb.doSelect( assert (
fake_connected_just_try_mydb._conn.expected_sql, fake_connected_just_try_mydb.doSelect(
fake_connected_just_try_mydb._conn.expected_params fake_connected_just_try_mydb._conn.expected_sql,
) == fake_connected_just_try_mydb._conn.expected_return fake_connected_just_try_mydb._conn.expected_params,
)
== fake_connected_just_try_mydb._conn.expected_return
)

View file

@ -2,6 +2,7 @@
""" Tests on opening hours helpers """ """ Tests on opening hours helpers """
import datetime import datetime
import pytest import pytest
from mylib import opening_hours from mylib import opening_hours
@ -12,14 +13,16 @@ from mylib import opening_hours
def test_parse_exceptional_closures_one_day_without_time_period(): def test_parse_exceptional_closures_one_day_without_time_period():
assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [{'days': [datetime.date(2017, 9, 22)], 'hours_periods': []}] assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [
{"days": [datetime.date(2017, 9, 22)], "hours_periods": []}
]
def test_parse_exceptional_closures_one_day_with_time_period(): def test_parse_exceptional_closures_one_day_with_time_period():
assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30"]) == [ assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30"]) == [
{ {
'days': [datetime.date(2017, 11, 26)], "days": [datetime.date(2017, 11, 26)],
'hours_periods': [{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}] "hours_periods": [{"start": datetime.time(9, 30), "stop": datetime.time(12, 30)}],
} }
] ]
@ -27,11 +30,11 @@ def test_parse_exceptional_closures_one_day_with_time_period():
def test_parse_exceptional_closures_one_day_with_multiple_time_periods(): def test_parse_exceptional_closures_one_day_with_multiple_time_periods():
assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30 14h-18h"]) == [ assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30 14h-18h"]) == [
{ {
'days': [datetime.date(2017, 11, 26)], "days": [datetime.date(2017, 11, 26)],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
] ],
} }
] ]
@ -39,8 +42,12 @@ def test_parse_exceptional_closures_one_day_with_multiple_time_periods():
def test_parse_exceptional_closures_full_days_period(): def test_parse_exceptional_closures_full_days_period():
assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017"]) == [ assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017"]) == [
{ {
'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)], "days": [
'hours_periods': [] datetime.date(2017, 9, 20),
datetime.date(2017, 9, 21),
datetime.date(2017, 9, 22),
],
"hours_periods": [],
} }
] ]
@ -53,8 +60,12 @@ def test_parse_exceptional_closures_invalid_days_period():
def test_parse_exceptional_closures_days_period_with_time_period(): def test_parse_exceptional_closures_days_period_with_time_period():
assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017 9h-12h"]) == [ assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017 9h-12h"]) == [
{ {
'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)], "days": [
'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}] datetime.date(2017, 9, 20),
datetime.date(2017, 9, 21),
datetime.date(2017, 9, 22),
],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}],
} }
] ]
@ -70,31 +81,38 @@ def test_parse_exceptional_closures_invalid_time_period():
def test_parse_exceptional_closures_multiple_periods(): def test_parse_exceptional_closures_multiple_periods():
assert opening_hours.parse_exceptional_closures(["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"]) == [ assert opening_hours.parse_exceptional_closures(
["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"]
) == [
{ {
'days': [ "days": [
datetime.date(2017, 9, 20), datetime.date(2017, 9, 20),
datetime.date(2017, 11, 25), datetime.date(2017, 11, 25),
datetime.date(2017, 11, 26), datetime.date(2017, 11, 26),
], ],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
] ],
} }
] ]
# #
# Tests on parse_normal_opening_hours() # Tests on parse_normal_opening_hours()
# #
def test_parse_normal_opening_hours_one_day(): def test_parse_normal_opening_hours_one_day():
assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [{'days': ["jeudi"], 'hours_periods': []}] assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [
{"days": ["jeudi"], "hours_periods": []}
]
def test_parse_normal_opening_hours_multiple_days(): def test_parse_normal_opening_hours_multiple_days():
assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [{'days': ["lundi", "jeudi"], 'hours_periods': []}] assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [
{"days": ["lundi", "jeudi"], "hours_periods": []}
]
def test_parse_normal_opening_hours_invalid_day(): def test_parse_normal_opening_hours_invalid_day():
@ -104,13 +122,17 @@ def test_parse_normal_opening_hours_invalid_day():
def test_parse_normal_opening_hours_one_days_period(): def test_parse_normal_opening_hours_one_days_period():
assert opening_hours.parse_normal_opening_hours(["lundi-jeudi"]) == [ assert opening_hours.parse_normal_opening_hours(["lundi-jeudi"]) == [
{'days': ["lundi", "mardi", "mercredi", "jeudi"], 'hours_periods': []} {"days": ["lundi", "mardi", "mercredi", "jeudi"], "hours_periods": []}
] ]
def test_parse_normal_opening_hours_one_day_with_one_time_period(): def test_parse_normal_opening_hours_one_day_with_one_time_period():
assert opening_hours.parse_normal_opening_hours(["jeudi 9h-12h"]) == [ assert opening_hours.parse_normal_opening_hours(["jeudi 9h-12h"]) == [
{'days': ["jeudi"], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}]}] {
"days": ["jeudi"],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}],
}
]
def test_parse_normal_opening_hours_invalid_days_period(): def test_parse_normal_opening_hours_invalid_days_period():
@ -122,7 +144,10 @@ def test_parse_normal_opening_hours_invalid_days_period():
def test_parse_normal_opening_hours_one_time_period(): def test_parse_normal_opening_hours_one_time_period():
assert opening_hours.parse_normal_opening_hours(["9h-18h30"]) == [ assert opening_hours.parse_normal_opening_hours(["9h-18h30"]) == [
{'days': [], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(18, 30)}]} {
"days": [],
"hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(18, 30)}],
}
] ]
@ -132,48 +157,60 @@ def test_parse_normal_opening_hours_invalid_time_period():
def test_parse_normal_opening_hours_multiple_periods(): def test_parse_normal_opening_hours_multiple_periods():
assert opening_hours.parse_normal_opening_hours(["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"]) == [ assert opening_hours.parse_normal_opening_hours(
["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"]
) == [
{ {
'days': ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi'], "days": ["lundi", "mardi", "mercredi", "jeudi", "vendredi"],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)},
{'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)},
] ],
}, },
{ {
'days': ['samedi'], "days": ["samedi"],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(18, 0)}, {"start": datetime.time(9, 30), "stop": datetime.time(18, 0)},
] ],
}, },
{ {
'days': ['dimanche'], "days": ["dimanche"],
'hours_periods': [ "hours_periods": [
{'start': datetime.time(9, 30), 'stop': datetime.time(12, 0)}, {"start": datetime.time(9, 30), "stop": datetime.time(12, 0)},
] ],
}, },
] ]
# #
# Tests on is_closed # Tests on is_closed
# #
exceptional_closures = ["22/09/2017", "20/09/2017-22/09/2017", "20/09/2017-22/09/2017 18/09/2017", "25/11/2017", "26/11/2017 9h30-12h30"] exceptional_closures = [
normal_opening_hours = ["lundi-mardi jeudi 9h30-12h30 14h-16h30", "mercredi vendredi 9h30-12h30 14h-17h"] "22/09/2017",
"20/09/2017-22/09/2017",
"20/09/2017-22/09/2017 18/09/2017",
"25/11/2017",
"26/11/2017 9h30-12h30",
]
normal_opening_hours = [
"lundi-mardi jeudi 9h30-12h30 14h-16h30",
"mercredi vendredi 9h30-12h30 14h-17h",
]
nonworking_public_holidays = [ nonworking_public_holidays = [
'1janvier', "1janvier",
'paques', "paques",
'lundi_paques', "lundi_paques",
'1mai', "1mai",
'8mai', "8mai",
'jeudi_ascension', "jeudi_ascension",
'lundi_pentecote', "lundi_pentecote",
'14juillet', "14juillet",
'15aout', "15aout",
'1novembre', "1novembre",
'11novembre', "11novembre",
'noel', "noel",
] ]
@ -182,12 +219,8 @@ def test_is_closed_when_normaly_closed_by_hour():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 1, 20, 15) when=datetime.datetime(2017, 5, 1, 20, 15),
) == { ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_on_exceptional_closure_full_day(): def test_is_closed_on_exceptional_closure_full_day():
@ -195,12 +228,8 @@ def test_is_closed_on_exceptional_closure_full_day():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 9, 22, 14, 15) when=datetime.datetime(2017, 9, 22, 14, 15),
) == { ) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": True}
'closed': True,
'exceptional_closure': True,
'exceptional_closure_all_day': True
}
def test_is_closed_on_exceptional_closure_day(): def test_is_closed_on_exceptional_closure_day():
@ -208,12 +237,8 @@ def test_is_closed_on_exceptional_closure_day():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 11, 26, 10, 30) when=datetime.datetime(2017, 11, 26, 10, 30),
) == { ) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': True,
'exceptional_closure_all_day': False
}
def test_is_closed_on_nonworking_public_holidays(): def test_is_closed_on_nonworking_public_holidays():
@ -221,12 +246,8 @@ def test_is_closed_on_nonworking_public_holidays():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 1, 1, 10, 30) when=datetime.datetime(2017, 1, 1, 10, 30),
) == { ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_when_normaly_closed_by_day(): def test_is_closed_when_normaly_closed_by_day():
@ -234,12 +255,8 @@ def test_is_closed_when_normaly_closed_by_day():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 6, 14, 15) when=datetime.datetime(2017, 5, 6, 14, 15),
) == { ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': True,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_is_closed_when_normaly_opened(): def test_is_closed_when_normaly_opened():
@ -247,12 +264,8 @@ def test_is_closed_when_normaly_opened():
normal_opening_hours_values=normal_opening_hours, normal_opening_hours_values=normal_opening_hours,
exceptional_closures_values=exceptional_closures, exceptional_closures_values=exceptional_closures,
nonworking_public_holidays_values=nonworking_public_holidays, nonworking_public_holidays_values=nonworking_public_holidays,
when=datetime.datetime(2017, 5, 2, 15, 15) when=datetime.datetime(2017, 5, 2, 15, 15),
) == { ) == {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False}
'closed': False,
'exceptional_closure': False,
'exceptional_closure_all_day': False
}
def test_easter_date(): def test_easter_date():
@ -272,18 +285,18 @@ def test_easter_date():
def test_nonworking_french_public_days_of_the_year(): def test_nonworking_french_public_days_of_the_year():
assert opening_hours.nonworking_french_public_days_of_the_year(2021) == { assert opening_hours.nonworking_french_public_days_of_the_year(2021) == {
'1janvier': datetime.date(2021, 1, 1), "1janvier": datetime.date(2021, 1, 1),
'paques': datetime.date(2021, 4, 4), "paques": datetime.date(2021, 4, 4),
'lundi_paques': datetime.date(2021, 4, 5), "lundi_paques": datetime.date(2021, 4, 5),
'1mai': datetime.date(2021, 5, 1), "1mai": datetime.date(2021, 5, 1),
'8mai': datetime.date(2021, 5, 8), "8mai": datetime.date(2021, 5, 8),
'jeudi_ascension': datetime.date(2021, 5, 13), "jeudi_ascension": datetime.date(2021, 5, 13),
'pentecote': datetime.date(2021, 5, 23), "pentecote": datetime.date(2021, 5, 23),
'lundi_pentecote': datetime.date(2021, 5, 24), "lundi_pentecote": datetime.date(2021, 5, 24),
'14juillet': datetime.date(2021, 7, 14), "14juillet": datetime.date(2021, 7, 14),
'15aout': datetime.date(2021, 8, 15), "15aout": datetime.date(2021, 8, 15),
'1novembre': datetime.date(2021, 11, 1), "1novembre": datetime.date(2021, 11, 1),
'11novembre': datetime.date(2021, 11, 11), "11novembre": datetime.date(2021, 11, 11),
'noel': datetime.date(2021, 12, 25), "noel": datetime.date(2021, 12, 25),
'saint_etienne': datetime.date(2021, 12, 26) "saint_etienne": datetime.date(2021, 12, 26),
} }

View file

@ -8,9 +8,11 @@ from mylib.oracle import OracleDB
class FakeCXOracleCursor: class FakeCXOracleCursor:
""" Fake cx_Oracle cursor """ """Fake cx_Oracle cursor"""
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
self.expected_sql = expected_sql self.expected_sql = expected_sql
self.expected_params = expected_params self.expected_params = expected_params
self.expected_return = expected_return self.expected_return = expected_return
@ -21,13 +23,25 @@ class FakeCXOracleCursor:
def execute(self, sql, **params): def execute(self, sql, **params):
assert self.opened assert self.opened
if self.expected_exception: if self.expected_exception:
raise cx_Oracle.Error(f'{self}.execute({sql}, {params}): expected exception') raise cx_Oracle.Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith('select '): if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
return self.expected_return return self.expected_return
def fetchall(self): def fetchall(self):
@ -43,13 +57,13 @@ class FakeCXOracleCursor:
def __repr__(self): def __repr__(self):
return ( return (
f'FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, ' f"FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, "
f'{self.expected_return}, {self.expected_just_try})' f"{self.expected_return}, {self.expected_just_try})"
) )
class FakeCXOracle: class FakeCXOracle:
""" Fake cx_Oracle connection """ """Fake cx_Oracle connection"""
expected_sql = None expected_sql = None
expected_params = {} expected_params = {}
@ -62,7 +76,9 @@ class FakeCXOracle:
allowed_kwargs = dict(dsn=str, user=str, password=(str, None)) allowed_kwargs = dict(dsn=str, user=str, password=(str, None))
for arg, value in kwargs.items(): for arg, value in kwargs.items():
assert arg in allowed_kwargs, f"Invalid arg {arg}='{value}'" assert arg in allowed_kwargs, f"Invalid arg {arg}='{value}'"
assert isinstance(value, allowed_kwargs[arg]), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})" assert isinstance(
value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
setattr(self, arg, value) setattr(self, arg, value)
def close(self): def close(self):
@ -70,9 +86,11 @@ class FakeCXOracle:
def cursor(self): def cursor(self):
return FakeCXOracleCursor( return FakeCXOracleCursor(
self.expected_sql, self.expected_params, self.expected_sql,
self.expected_return, self.expected_just_try or self.just_try, self.expected_params,
self.expected_exception self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
) )
def commit(self): def commit(self):
@ -100,19 +118,19 @@ def fake_cxoracle_connect_just_try(**kwargs):
@pytest.fixture @pytest.fixture
def test_oracledb(): def test_oracledb():
return OracleDB('127.0.0.1/dbname', 'user', 'password') return OracleDB("127.0.0.1/dbname", "user", "password")
@pytest.fixture @pytest.fixture
def fake_oracledb(mocker): def fake_oracledb(mocker):
mocker.patch('cx_Oracle.connect', fake_cxoracle_connect) mocker.patch("cx_Oracle.connect", fake_cxoracle_connect)
return OracleDB('127.0.0.1/dbname', 'user', 'password') return OracleDB("127.0.0.1/dbname", "user", "password")
@pytest.fixture @pytest.fixture
def fake_just_try_oracledb(mocker): def fake_just_try_oracledb(mocker):
mocker.patch('cx_Oracle.connect', fake_cxoracle_connect_just_try) mocker.patch("cx_Oracle.connect", fake_cxoracle_connect_just_try)
return OracleDB('127.0.0.1/dbname', 'user', 'password', just_try=True) return OracleDB("127.0.0.1/dbname", "user", "password", just_try=True)
@pytest.fixture @pytest.fixture
@ -127,13 +145,22 @@ def fake_connected_just_try_oracledb(fake_just_try_oracledb):
return fake_just_try_oracledb return fake_just_try_oracledb
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs): def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
return expected_return return expected_return
return mock_args return mock_args
@ -141,13 +168,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode" assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
return expected_return return expected_return
return mock_doSQL return mock_doSQL
@ -161,15 +197,11 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter(): def test_combine_params_with_to_add_parameter():
assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict( assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs(): def test_combine_params_with_kargs():
assert OracleDB._combine_params(dict(test1=1), test2=2) == dict( assert OracleDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter(): def test_combine_params_with_kargs_and_to_add_parameter():
@ -179,19 +211,16 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved(): def test_format_where_clauses_params_are_preserved():
args = ('test = test', dict(test1=1)) args = ("test = test", dict(test1=1))
assert OracleDB._format_where_clauses(*args) == args assert OracleDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw(): def test_format_where_clauses_raw():
assert OracleDB._format_where_clauses('test = test') == (('test = test'), {}) assert OracleDB._format_where_clauses("test = test") == ("test = test", {})
def test_format_where_clauses_tuple_clause_with_params(): def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ( where_clauses = ("test1 = :test1 AND test2 = :test2", dict(test1=1, test2=2))
'test1 = :test1 AND test2 = :test2',
dict(test1=1, test2=2)
)
assert OracleDB._format_where_clauses(where_clauses) == where_clauses assert OracleDB._format_where_clauses(where_clauses) == where_clauses
@ -199,27 +228,23 @@ def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert OracleDB._format_where_clauses(where_clauses) == ( assert OracleDB._format_where_clauses(where_clauses) == (
'"test1" = :test1 AND "test2" = :test2', '"test1" = :test1 AND "test2" = :test2',
where_clauses where_clauses,
) )
def test_format_where_clauses_combined_types(): def test_format_where_clauses_combined_types():
where_clauses = ( where_clauses = ("test1 = 1", ("test2 LIKE :test2", dict(test2=2)), dict(test3=3, test4=4))
'test1 = 1',
('test2 LIKE :test2', dict(test2=2)),
dict(test3=3, test4=4)
)
assert OracleDB._format_where_clauses(where_clauses) == ( assert OracleDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE :test2 AND "test3" = :test3 AND "test4" = :test4', 'test1 = 1 AND test2 LIKE :test2 AND "test3" = :test3 AND "test4" = :test4',
dict(test2=2, test3=3, test4=4) dict(test2=2, test3=3, test4=4),
) )
def test_format_where_clauses_with_where_op(): def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert OracleDB._format_where_clauses(where_clauses, where_op='OR') == ( assert OracleDB._format_where_clauses(where_clauses, where_op="OR") == (
'"test1" = :test1 OR "test2" = :test2', '"test1" = :test1 OR "test2" = :test2',
where_clauses where_clauses,
) )
@ -228,7 +253,7 @@ def test_add_where_clauses():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert OracleDB._add_where_clauses(sql, None, where_clauses) == ( assert OracleDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE "test1" = :test1 AND "test2" = :test2', sql + ' WHERE "test1" = :test1 AND "test2" = :test2',
where_clauses where_clauses,
) )
@ -238,26 +263,26 @@ def test_add_where_clauses_preserved_params():
params = dict(fake1=1) params = dict(fake1=1)
assert OracleDB._add_where_clauses(sql, params.copy(), where_clauses) == ( assert OracleDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE "test1" = :test1 AND "test2" = :test2', sql + ' WHERE "test1" = :test1 AND "test2" = :test2',
dict(**where_clauses, **params) dict(**where_clauses, **params),
) )
def test_add_where_clauses_with_op(): def test_add_where_clauses_with_op():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = ('test1=1', 'test2=2') where_clauses = ("test1=1", "test2=2")
assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + ' WHERE test1=1 OR test2=2', sql + " WHERE test1=1 OR test2=2",
{} {},
) )
def test_add_where_clauses_with_duplicated_field(): def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=:test1" sql = "UPDATE table SET test1=:test1"
params = dict(test1='new_value') params = dict(test1="new_value")
where_clauses = dict(test1='where_value') where_clauses = dict(test1="where_value")
assert OracleDB._add_where_clauses(sql, params, where_clauses) == ( assert OracleDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE "test1" = :test1_1', sql + ' WHERE "test1" = :test1_1',
dict(test1='new_value', test1_1='where_value') dict(test1="new_value", test1_1="where_value"),
) )
@ -269,74 +294,72 @@ def test_quote_table_name():
def test_insert(mocker, test_oracledb): def test_insert(mocker, test_oracledb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)', 'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)', values
values ),
)
) )
assert test_oracledb.insert('mytable', values) assert test_oracledb.insert("mytable", values)
def test_insert_just_try(mocker, test_oracledb): def test_insert_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.insert('mytable', dict(test1=1, test2=2), just_try=True) assert test_oracledb.insert("mytable", dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_oracledb): def test_update(mocker, test_oracledb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND "test4" = :test4', 'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND'
dict(**values, **where_clauses) ' "test4" = :test4',
) dict(**values, **where_clauses),
),
) )
assert test_oracledb.update('mytable', values, where_clauses) assert test_oracledb.update("mytable", values, where_clauses)
def test_update_just_try(mocker, test_oracledb): def test_update_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.update('mytable', dict(test1=1, test2=2), None, just_try=True) assert test_oracledb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_oracledb): def test_delete(mocker, test_oracledb):
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2', 'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2', where_clauses
where_clauses ),
)
) )
assert test_oracledb.delete('mytable', where_clauses) assert test_oracledb.delete("mytable", where_clauses)
def test_delete_just_try(mocker, test_oracledb): def test_delete_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try)
assert test_oracledb.delete('mytable', None, just_try=True) assert test_oracledb.delete("mytable", None, just_try=True)
def test_truncate(mocker, test_oracledb): def test_truncate(mocker, test_oracledb):
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSQL', "mylib.oracle.OracleDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
) )
assert test_oracledb.truncate('mytable') assert test_oracledb.truncate("mytable")
def test_truncate_just_try(mocker, test_oracledb): def test_truncate_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try)
assert test_oracledb.truncate('mytable', just_try=True) assert test_oracledb.truncate("mytable", just_try=True)
def test_select(mocker, test_oracledb): def test_select(mocker, test_oracledb):
fields = ('field1', 'field2') fields = ("field1", "field2")
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
expected_return = [ expected_return = [
dict(field1=1, field2=2), dict(field1=1, field2=2),
@ -344,30 +367,30 @@ def test_select(mocker, test_oracledb):
] ]
order_by = "field1, DESC" order_by = "field1, DESC"
mocker.patch( mocker.patch(
'mylib.oracle.OracleDB.doSelect', "mylib.oracle.OracleDB.doSelect",
generate_mock_doSQL( generate_mock_doSQL(
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4 ORDER BY ' + order_by, 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4'
where_clauses, expected_return " ORDER BY " + order_by,
) where_clauses,
expected_return,
),
) )
assert test_oracledb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return assert (
test_oracledb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
)
def test_select_without_field_and_order_by(mocker, test_oracledb): def test_select_without_field_and_order_by(mocker, test_oracledb):
mocker.patch( mocker.patch("mylib.oracle.OracleDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"'))
'mylib.oracle.OracleDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM "mytable"'
)
)
assert test_oracledb.select('mytable') assert test_oracledb.select("mytable")
def test_select_just_try(mocker, test_oracledb): def test_select_just_try(mocker, test_oracledb):
mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try)
assert test_oracledb.select('mytable', None, None, just_try=True) assert test_oracledb.select("mytable", None, None, just_try=True)
# #
# Tests on main methods # Tests on main methods
@ -376,17 +399,10 @@ def test_select_just_try(mocker, test_oracledb):
def test_connect(mocker, test_oracledb): def test_connect(mocker, test_oracledb):
expected_kwargs = dict( expected_kwargs = dict(
dsn=test_oracledb._dsn, dsn=test_oracledb._dsn, user=test_oracledb._user, password=test_oracledb._pwd
user=test_oracledb._user,
password=test_oracledb._pwd
) )
mocker.patch( mocker.patch("cx_Oracle.connect", generate_mock_args(expected_kwargs=expected_kwargs))
'cx_Oracle.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_oracledb.connect() assert test_oracledb.connect()
@ -400,50 +416,62 @@ def test_close_connected(fake_connected_oracledb):
def test_doSQL(fake_connected_oracledb): def test_doSQL(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table WHERE test1 = :test1' fake_connected_oracledb._conn.expected_sql = "DELETE FROM table WHERE test1 = :test1"
fake_connected_oracledb._conn.expected_params = dict(test1=1) fake_connected_oracledb._conn.expected_params = dict(test1=1)
fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params) fake_connected_oracledb.doSQL(
fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params
)
def test_doSQL_without_params(fake_connected_oracledb): def test_doSQL_without_params(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table' fake_connected_oracledb._conn.expected_sql = "DELETE FROM table"
fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql) fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_oracledb): def test_doSQL_just_try(fake_connected_just_try_oracledb):
assert fake_connected_just_try_oracledb.doSQL('DELETE FROM table') assert fake_connected_just_try_oracledb.doSQL("DELETE FROM table")
def test_doSQL_on_exception(fake_connected_oracledb): def test_doSQL_on_exception(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_exception = True fake_connected_oracledb._conn.expected_exception = True
assert fake_connected_oracledb.doSQL('DELETE FROM table') is False assert fake_connected_oracledb.doSQL("DELETE FROM table") is False
def test_doSelect(fake_connected_oracledb): def test_doSelect(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1' fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1"
fake_connected_oracledb._conn.expected_params = dict(test1=1) fake_connected_oracledb._conn.expected_params = dict(test1=1)
fake_connected_oracledb._conn.expected_return = [dict(test1=1)] fake_connected_oracledb._conn.expected_return = [dict(test1=1)]
assert fake_connected_oracledb.doSelect( assert (
fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb.doSelect(
fake_connected_oracledb._conn.expected_params) == fake_connected_oracledb._conn.expected_return fake_connected_oracledb._conn.expected_sql,
fake_connected_oracledb._conn.expected_params,
)
== fake_connected_oracledb._conn.expected_return
)
def test_doSelect_without_params(fake_connected_oracledb): def test_doSelect_without_params(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table' fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table"
fake_connected_oracledb._conn.expected_return = [dict(test1=1)] fake_connected_oracledb._conn.expected_return = [dict(test1=1)]
assert fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql) == fake_connected_oracledb._conn.expected_return assert (
fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql)
== fake_connected_oracledb._conn.expected_return
)
def test_doSelect_on_exception(fake_connected_oracledb): def test_doSelect_on_exception(fake_connected_oracledb):
fake_connected_oracledb._conn.expected_exception = True fake_connected_oracledb._conn.expected_exception = True
assert fake_connected_oracledb.doSelect('SELECT * FROM table') is False assert fake_connected_oracledb.doSelect("SELECT * FROM table") is False
def test_doSelect_just_try(fake_connected_just_try_oracledb): def test_doSelect_just_try(fake_connected_just_try_oracledb):
fake_connected_just_try_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1' fake_connected_just_try_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1"
fake_connected_just_try_oracledb._conn.expected_params = dict(test1=1) fake_connected_just_try_oracledb._conn.expected_params = dict(test1=1)
fake_connected_just_try_oracledb._conn.expected_return = [dict(test1=1)] fake_connected_just_try_oracledb._conn.expected_return = [dict(test1=1)]
assert fake_connected_just_try_oracledb.doSelect( assert (
fake_connected_just_try_oracledb._conn.expected_sql, fake_connected_just_try_oracledb.doSelect(
fake_connected_just_try_oracledb._conn.expected_params fake_connected_just_try_oracledb._conn.expected_sql,
) == fake_connected_just_try_oracledb._conn.expected_return fake_connected_just_try_oracledb._conn.expected_params,
)
== fake_connected_just_try_oracledb._conn.expected_return
)

View file

@ -8,9 +8,11 @@ from mylib.pgsql import PgDB
class FakePsycopg2Cursor: class FakePsycopg2Cursor:
""" Fake Psycopg2 cursor """ """Fake Psycopg2 cursor"""
def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): def __init__(
self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception
):
self.expected_sql = expected_sql self.expected_sql = expected_sql
self.expected_params = expected_params self.expected_params = expected_params
self.expected_return = expected_return self.expected_return = expected_return
@ -19,13 +21,25 @@ class FakePsycopg2Cursor:
def execute(self, sql, params=None): def execute(self, sql, params=None):
if self.expected_exception: if self.expected_exception:
raise psycopg2.Error(f'{self}.execute({sql}, {params}): expected exception') raise psycopg2.Error(f"{self}.execute({sql}, {params}): expected exception")
if self.expected_just_try and not sql.lower().startswith('select '): if self.expected_just_try and not sql.lower().startswith("select "):
assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode"
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) assert (
sql == self.expected_sql
), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (
self,
sql,
self.expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) assert (
params == self.expected_params
), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (
self,
params,
self.expected_params,
)
return self.expected_return return self.expected_return
def fetchall(self): def fetchall(self):
@ -33,13 +47,13 @@ class FakePsycopg2Cursor:
def __repr__(self): def __repr__(self):
return ( return (
f'FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, ' f"FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, "
f'{self.expected_return}, {self.expected_just_try})' f"{self.expected_return}, {self.expected_just_try})"
) )
class FakePsycopg2: class FakePsycopg2:
""" Fake Psycopg2 connection """ """Fake Psycopg2 connection"""
expected_sql = None expected_sql = None
expected_params = None expected_params = None
@ -52,8 +66,9 @@ class FakePsycopg2:
allowed_kwargs = dict(dbname=str, user=str, password=(str, None), host=str) allowed_kwargs = dict(dbname=str, user=str, password=(str, None), host=str)
for arg, value in kwargs.items(): for arg, value in kwargs.items():
assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"' assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"'
assert isinstance(value, allowed_kwargs[arg]), \ assert isinstance(
f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})' value, allowed_kwargs[arg]
), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})"
setattr(self, arg, value) setattr(self, arg, value)
def close(self): def close(self):
@ -63,14 +78,16 @@ class FakePsycopg2:
self._check_just_try() self._check_just_try()
assert len(arg) == 1 and isinstance(arg[0], str) assert len(arg) == 1 and isinstance(arg[0], str)
if self.expected_exception: if self.expected_exception:
raise psycopg2.Error(f'set_client_encoding({arg[0]}): Expected exception') raise psycopg2.Error(f"set_client_encoding({arg[0]}): Expected exception")
return self.expected_return return self.expected_return
def cursor(self): def cursor(self):
return FakePsycopg2Cursor( return FakePsycopg2Cursor(
self.expected_sql, self.expected_params, self.expected_sql,
self.expected_return, self.expected_just_try or self.just_try, self.expected_params,
self.expected_exception self.expected_return,
self.expected_just_try or self.just_try,
self.expected_exception,
) )
def commit(self): def commit(self):
@ -98,19 +115,19 @@ def fake_psycopg2_connect_just_try(**kwargs):
@pytest.fixture @pytest.fixture
def test_pgdb(): def test_pgdb():
return PgDB('127.0.0.1', 'user', 'password', 'dbname') return PgDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_pgdb(mocker): def fake_pgdb(mocker):
mocker.patch('psycopg2.connect', fake_psycopg2_connect) mocker.patch("psycopg2.connect", fake_psycopg2_connect)
return PgDB('127.0.0.1', 'user', 'password', 'dbname') return PgDB("127.0.0.1", "user", "password", "dbname")
@pytest.fixture @pytest.fixture
def fake_just_try_pgdb(mocker): def fake_just_try_pgdb(mocker):
mocker.patch('psycopg2.connect', fake_psycopg2_connect_just_try) mocker.patch("psycopg2.connect", fake_psycopg2_connect_just_try)
return PgDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True) return PgDB("127.0.0.1", "user", "password", "dbname", just_try=True)
@pytest.fixture @pytest.fixture
@ -125,13 +142,22 @@ def fake_connected_just_try_pgdb(fake_just_try_pgdb):
return fake_just_try_pgdb return fake_just_try_pgdb
def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_args(
expected_args=(), expected_kwargs={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_args(*args, **kwargs): def mock_args(*args, **kwargs):
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (
args,
expected_args,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (
kwargs,
expected_kwargs,
)
return expected_return return expected_return
return mock_args return mock_args
@ -139,13 +165,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum
assert False, "doSQL() may not be executed in just try mode" assert False, "doSQL() may not be executed in just try mode"
def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value def generate_mock_doSQL(
expected_sql, expected_params={}, expected_return=True
): # pylint: disable=dangerous-default-value
def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (
sql,
expected_sql,
)
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (
params,
expected_params,
)
return expected_return return expected_return
return mock_doSQL return mock_doSQL
@ -159,15 +194,11 @@ mock_doSelect_just_try = mock_doSQL_just_try
def test_combine_params_with_to_add_parameter(): def test_combine_params_with_to_add_parameter():
assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict( assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs(): def test_combine_params_with_kargs():
assert PgDB._combine_params(dict(test1=1), test2=2) == dict( assert PgDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2)
test1=1, test2=2
)
def test_combine_params_with_kargs_and_to_add_parameter(): def test_combine_params_with_kargs_and_to_add_parameter():
@ -177,19 +208,16 @@ def test_combine_params_with_kargs_and_to_add_parameter():
def test_format_where_clauses_params_are_preserved(): def test_format_where_clauses_params_are_preserved():
args = ('test = test', dict(test1=1)) args = ("test = test", dict(test1=1))
assert PgDB._format_where_clauses(*args) == args assert PgDB._format_where_clauses(*args) == args
def test_format_where_clauses_raw(): def test_format_where_clauses_raw():
assert PgDB._format_where_clauses('test = test') == (('test = test'), {}) assert PgDB._format_where_clauses("test = test") == ("test = test", {})
def test_format_where_clauses_tuple_clause_with_params(): def test_format_where_clauses_tuple_clause_with_params():
where_clauses = ( where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2))
'test1 = %(test1)s AND test2 = %(test2)s',
dict(test1=1, test2=2)
)
assert PgDB._format_where_clauses(where_clauses) == where_clauses assert PgDB._format_where_clauses(where_clauses) == where_clauses
@ -197,27 +225,23 @@ def test_format_where_clauses_dict():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert PgDB._format_where_clauses(where_clauses) == ( assert PgDB._format_where_clauses(where_clauses) == (
'"test1" = %(test1)s AND "test2" = %(test2)s', '"test1" = %(test1)s AND "test2" = %(test2)s',
where_clauses where_clauses,
) )
def test_format_where_clauses_combined_types(): def test_format_where_clauses_combined_types():
where_clauses = ( where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4))
'test1 = 1',
('test2 LIKE %(test2)s', dict(test2=2)),
dict(test3=3, test4=4)
)
assert PgDB._format_where_clauses(where_clauses) == ( assert PgDB._format_where_clauses(where_clauses) == (
'test1 = 1 AND test2 LIKE %(test2)s AND "test3" = %(test3)s AND "test4" = %(test4)s', 'test1 = 1 AND test2 LIKE %(test2)s AND "test3" = %(test3)s AND "test4" = %(test4)s',
dict(test2=2, test3=3, test4=4) dict(test2=2, test3=3, test4=4),
) )
def test_format_where_clauses_with_where_op(): def test_format_where_clauses_with_where_op():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert PgDB._format_where_clauses(where_clauses, where_op='OR') == ( assert PgDB._format_where_clauses(where_clauses, where_op="OR") == (
'"test1" = %(test1)s OR "test2" = %(test2)s', '"test1" = %(test1)s OR "test2" = %(test2)s',
where_clauses where_clauses,
) )
@ -226,7 +250,7 @@ def test_add_where_clauses():
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
assert PgDB._add_where_clauses(sql, None, where_clauses) == ( assert PgDB._add_where_clauses(sql, None, where_clauses) == (
sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s', sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s',
where_clauses where_clauses,
) )
@ -236,26 +260,26 @@ def test_add_where_clauses_preserved_params():
params = dict(fake1=1) params = dict(fake1=1)
assert PgDB._add_where_clauses(sql, params.copy(), where_clauses) == ( assert PgDB._add_where_clauses(sql, params.copy(), where_clauses) == (
sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s', sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s',
dict(**where_clauses, **params) dict(**where_clauses, **params),
) )
def test_add_where_clauses_with_op(): def test_add_where_clauses_with_op():
sql = "SELECT * FROM table" sql = "SELECT * FROM table"
where_clauses = ('test1=1', 'test2=2') where_clauses = ("test1=1", "test2=2")
assert PgDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( assert PgDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == (
sql + ' WHERE test1=1 OR test2=2', sql + " WHERE test1=1 OR test2=2",
{} {},
) )
def test_add_where_clauses_with_duplicated_field(): def test_add_where_clauses_with_duplicated_field():
sql = "UPDATE table SET test1=%(test1)s" sql = "UPDATE table SET test1=%(test1)s"
params = dict(test1='new_value') params = dict(test1="new_value")
where_clauses = dict(test1='where_value') where_clauses = dict(test1="where_value")
assert PgDB._add_where_clauses(sql, params, where_clauses) == ( assert PgDB._add_where_clauses(sql, params, where_clauses) == (
sql + ' WHERE "test1" = %(test1_1)s', sql + ' WHERE "test1" = %(test1_1)s',
dict(test1='new_value', test1_1='where_value') dict(test1="new_value", test1_1="where_value"),
) )
@ -267,74 +291,70 @@ def test_quote_table_name():
def test_insert(mocker, test_pgdb): def test_insert(mocker, test_pgdb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSQL', "mylib.pgsql.PgDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)', 'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)', values
values ),
)
) )
assert test_pgdb.insert('mytable', values) assert test_pgdb.insert("mytable", values)
def test_insert_just_try(mocker, test_pgdb): def test_insert_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.insert('mytable', dict(test1=1, test2=2), just_try=True) assert test_pgdb.insert("mytable", dict(test1=1, test2=2), just_try=True)
def test_update(mocker, test_pgdb): def test_update(mocker, test_pgdb):
values = dict(test1=1, test2=2) values = dict(test1=1, test2=2)
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSQL', "mylib.pgsql.PgDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" = %(test3)s AND "test4" = %(test4)s', 'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" ='
dict(**values, **where_clauses) ' %(test3)s AND "test4" = %(test4)s',
) dict(**values, **where_clauses),
),
) )
assert test_pgdb.update('mytable', values, where_clauses) assert test_pgdb.update("mytable", values, where_clauses)
def test_update_just_try(mocker, test_pgdb): def test_update_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.update('mytable', dict(test1=1, test2=2), None, just_try=True) assert test_pgdb.update("mytable", dict(test1=1, test2=2), None, just_try=True)
def test_delete(mocker, test_pgdb): def test_delete(mocker, test_pgdb):
where_clauses = dict(test1=1, test2=2) where_clauses = dict(test1=1, test2=2)
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSQL', "mylib.pgsql.PgDB.doSQL",
generate_mock_doSQL( generate_mock_doSQL(
'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s', 'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s', where_clauses
where_clauses ),
)
) )
assert test_pgdb.delete('mytable', where_clauses) assert test_pgdb.delete("mytable", where_clauses)
def test_delete_just_try(mocker, test_pgdb): def test_delete_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try)
assert test_pgdb.delete('mytable', None, just_try=True) assert test_pgdb.delete("mytable", None, just_try=True)
def test_truncate(mocker, test_pgdb): def test_truncate(mocker, test_pgdb):
mocker.patch( mocker.patch("mylib.pgsql.PgDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None))
'mylib.pgsql.PgDB.doSQL',
generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)
)
assert test_pgdb.truncate('mytable') assert test_pgdb.truncate("mytable")
def test_truncate_just_try(mocker, test_pgdb): def test_truncate_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try)
assert test_pgdb.truncate('mytable', just_try=True) assert test_pgdb.truncate("mytable", just_try=True)
def test_select(mocker, test_pgdb): def test_select(mocker, test_pgdb):
fields = ('field1', 'field2') fields = ("field1", "field2")
where_clauses = dict(test3=3, test4=4) where_clauses = dict(test3=3, test4=4)
expected_return = [ expected_return = [
dict(field1=1, field2=2), dict(field1=1, field2=2),
@ -342,30 +362,28 @@ def test_select(mocker, test_pgdb):
] ]
order_by = "field1, DESC" order_by = "field1, DESC"
mocker.patch( mocker.patch(
'mylib.pgsql.PgDB.doSelect', "mylib.pgsql.PgDB.doSelect",
generate_mock_doSQL( generate_mock_doSQL(
'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" = %(test4)s ORDER BY ' + order_by, 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" ='
where_clauses, expected_return " %(test4)s ORDER BY " + order_by,
) where_clauses,
expected_return,
),
) )
assert test_pgdb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return assert test_pgdb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return
def test_select_without_field_and_order_by(mocker, test_pgdb): def test_select_without_field_and_order_by(mocker, test_pgdb):
mocker.patch( mocker.patch("mylib.pgsql.PgDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"'))
'mylib.pgsql.PgDB.doSelect',
generate_mock_doSQL(
'SELECT * FROM "mytable"'
)
)
assert test_pgdb.select('mytable') assert test_pgdb.select("mytable")
def test_select_just_try(mocker, test_pgdb): def test_select_just_try(mocker, test_pgdb):
mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try) mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try)
assert test_pgdb.select('mytable', None, None, just_try=True) assert test_pgdb.select("mytable", None, None, just_try=True)
# #
# Tests on main methods # Tests on main methods
@ -374,18 +392,10 @@ def test_select_just_try(mocker, test_pgdb):
def test_connect(mocker, test_pgdb): def test_connect(mocker, test_pgdb):
expected_kwargs = dict( expected_kwargs = dict(
dbname=test_pgdb._db, dbname=test_pgdb._db, user=test_pgdb._user, host=test_pgdb._host, password=test_pgdb._pwd
user=test_pgdb._user,
host=test_pgdb._host,
password=test_pgdb._pwd
) )
mocker.patch( mocker.patch("psycopg2.connect", generate_mock_args(expected_kwargs=expected_kwargs))
'psycopg2.connect',
generate_mock_args(
expected_kwargs=expected_kwargs
)
)
assert test_pgdb.connect() assert test_pgdb.connect()
@ -399,61 +409,74 @@ def test_close_connected(fake_connected_pgdb):
def test_setEncoding(fake_connected_pgdb): def test_setEncoding(fake_connected_pgdb):
assert fake_connected_pgdb.setEncoding('utf8') assert fake_connected_pgdb.setEncoding("utf8")
def test_setEncoding_not_connected(fake_pgdb): def test_setEncoding_not_connected(fake_pgdb):
assert fake_pgdb.setEncoding('utf8') is False assert fake_pgdb.setEncoding("utf8") is False
def test_setEncoding_on_exception(fake_connected_pgdb): def test_setEncoding_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.setEncoding('utf8') is False assert fake_connected_pgdb.setEncoding("utf8") is False
def test_doSQL(fake_connected_pgdb): def test_doSQL(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s' fake_connected_pgdb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s"
fake_connected_pgdb._conn.expected_params = dict(test1=1) fake_connected_pgdb._conn.expected_params = dict(test1=1)
fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params) fake_connected_pgdb.doSQL(
fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params
)
def test_doSQL_without_params(fake_connected_pgdb): def test_doSQL_without_params(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table' fake_connected_pgdb._conn.expected_sql = "DELETE FROM table"
fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql) fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql)
def test_doSQL_just_try(fake_connected_just_try_pgdb): def test_doSQL_just_try(fake_connected_just_try_pgdb):
assert fake_connected_just_try_pgdb.doSQL('DELETE FROM table') assert fake_connected_just_try_pgdb.doSQL("DELETE FROM table")
def test_doSQL_on_exception(fake_connected_pgdb): def test_doSQL_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.doSQL('DELETE FROM table') is False assert fake_connected_pgdb.doSQL("DELETE FROM table") is False
def test_doSelect(fake_connected_pgdb): def test_doSelect(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_pgdb._conn.expected_params = dict(test1=1) fake_connected_pgdb._conn.expected_params = dict(test1=1)
fake_connected_pgdb._conn.expected_return = [dict(test1=1)] fake_connected_pgdb._conn.expected_return = [dict(test1=1)]
assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params) == fake_connected_pgdb._conn.expected_return assert (
fake_connected_pgdb.doSelect(
fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params
)
== fake_connected_pgdb._conn.expected_return
)
def test_doSelect_without_params(fake_connected_pgdb): def test_doSelect_without_params(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table' fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table"
fake_connected_pgdb._conn.expected_return = [dict(test1=1)] fake_connected_pgdb._conn.expected_return = [dict(test1=1)]
assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql) == fake_connected_pgdb._conn.expected_return assert (
fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql)
== fake_connected_pgdb._conn.expected_return
)
def test_doSelect_on_exception(fake_connected_pgdb): def test_doSelect_on_exception(fake_connected_pgdb):
fake_connected_pgdb._conn.expected_exception = True fake_connected_pgdb._conn.expected_exception = True
assert fake_connected_pgdb.doSelect('SELECT * FROM table') is False assert fake_connected_pgdb.doSelect("SELECT * FROM table") is False
def test_doSelect_just_try(fake_connected_just_try_pgdb): def test_doSelect_just_try(fake_connected_just_try_pgdb):
fake_connected_just_try_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' fake_connected_just_try_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s"
fake_connected_just_try_pgdb._conn.expected_params = dict(test1=1) fake_connected_just_try_pgdb._conn.expected_params = dict(test1=1)
fake_connected_just_try_pgdb._conn.expected_return = [dict(test1=1)] fake_connected_just_try_pgdb._conn.expected_return = [dict(test1=1)]
assert fake_connected_just_try_pgdb.doSelect( assert (
fake_connected_just_try_pgdb._conn.expected_sql, fake_connected_just_try_pgdb.doSelect(
fake_connected_just_try_pgdb._conn.expected_params fake_connected_just_try_pgdb._conn.expected_sql,
) == fake_connected_just_try_pgdb._conn.expected_return fake_connected_just_try_pgdb._conn.expected_params,
)
== fake_connected_just_try_pgdb._conn.expected_return
)

View file

@ -3,13 +3,14 @@
import datetime import datetime
import os import os
import pytest import pytest
from mylib.telltale import TelltaleFile from mylib.telltale import TelltaleFile
def test_create_telltale_file(tmp_path): def test_create_telltale_file(tmp_path):
filename = 'test' filename = "test"
file = TelltaleFile(filename=filename, dirpath=tmp_path) file = TelltaleFile(filename=filename, dirpath=tmp_path)
assert file.filename == filename assert file.filename == filename
assert file.dirpath == tmp_path assert file.dirpath == tmp_path
@ -24,15 +25,15 @@ def test_create_telltale_file(tmp_path):
def test_create_telltale_file_with_filepath_and_invalid_dirpath(): def test_create_telltale_file_with_filepath_and_invalid_dirpath():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
TelltaleFile(filepath='/tmp/test', dirpath='/var/tmp') TelltaleFile(filepath="/tmp/test", dirpath="/var/tmp")
def test_create_telltale_file_with_filepath_and_invalid_filename(): def test_create_telltale_file_with_filepath_and_invalid_filename():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
TelltaleFile(filepath='/tmp/test', filename='other') TelltaleFile(filepath="/tmp/test", filename="other")
def test_remove_telltale_file(tmp_path): def test_remove_telltale_file(tmp_path):
file = TelltaleFile(filename='test', dirpath=tmp_path) file = TelltaleFile(filename="test", dirpath=tmp_path)
file.update() file.update()
assert file.remove() assert file.remove()