""" MySQL client """

import logging
import sys

import MySQLdb
from MySQLdb._exceptions import Error

from mylib.db import DB, DBFailToConnect

log = logging.getLogger(__name__)


class MyDB(DB):
    """MySQL client"""

    # Connection settings, set by __init__ and used by connect()
    _host = None
    _user = None
    _pwd = None
    _db = None

    def __init__(self, host, user, pwd, db, charset=None, **kwargs):
        """
        :param host: The MySQL server host
        :param user: The MySQL user name
        :param pwd: The MySQL user password
        :param db: The database name
        :param charset: The connection charset (optional, default: "utf8")
        :param kwargs: Extra keyword arguments passed to the DB base class
        """
        self._host = host
        self._user = user
        self._pwd = pwd
        self._db = db
        # NOTE: MySQL's "utf8" is the 3-byte utf8mb3 charset; pass
        # charset="utf8mb4" to be able to store 4-byte characters.
        self._charset = charset if charset else "utf8"
        super().__init__(**kwargs)

    def connect(self, exit_on_error=True):
        """
        Connect to MySQL server (does nothing if self._conn is already set)

        :param exit_on_error: If True (default), exit the process with status 1
                              on connection failure, otherwise raise DBFailToConnect
        :return: True once connected
        :rtype: bool
        :raises DBFailToConnect: On connection failure if exit_on_error is False
        """
        if self._conn is None:
            try:
                self._conn = MySQLdb.connect(
                    host=self._host,
                    user=self._user,
                    passwd=self._pwd,
                    db=self._db,
                    charset=self._charset,
                    use_unicode=True,
                )
            except Error as err:
                log.fatal(
                    "An error occurred during MySQL database connection (%s@%s:%s).",
                    self._user,
                    self._host,
                    self._db,
                    exc_info=1,
                )
                if exit_on_error:
                    sys.exit(1)
                else:
                    raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err
        return True

    def doSQL(self, sql, params=None):
        """
        Run SQL query and commit changes (rollback on error)

        In just-try mode, the query is only logged and not executed.

        :param sql: The SQL query
        :param params: The SQL query's parameters as dict (optional)

        :return: True on success, False otherwise
        :rtype: bool
        """
        if self.just_try:
            log.debug("Just-try mode : do not really execute SQL query '%s'", sql)
            return True

        cursor = self._conn.cursor()
        try:
            self._log_query(sql, params)
            cursor.execute(sql, params)
            self._conn.commit()
            return True
        except Error:
            self._log_query_exception(sql, params)
            self._conn.rollback()
            return False
        finally:
            # Release the cursor deterministically instead of relying on GC
            cursor.close()

    def doSelect(self, sql, params=None):
        """
        Run SELECT SQL query and return list of selected rows as dict

        :param sql: The SQL query
        :param params: The SQL query's parameters as dict (optional)

        :return: List of selected rows as dict (column name -> value) on success,
                 False otherwise
        :rtype: list, bool
        """
        try:
            self._log_query(sql, params)
            cursor = self._conn.cursor()
            try:
                cursor.execute(sql, params)
                # Column names are computed once rather than once per row.
                # description is None when the statement has no result set,
                # in which case there are no rows to map either.
                fields = [column[0] for column in cursor.description or ()]
                return [dict(zip(fields, row)) for row in cursor.fetchall()]
            finally:
                cursor.close()
        except Error:
            self._log_query_exception(sql, params)
            return False

    @staticmethod
    def _quote_table_name(table):
        """
        Quote table name

        A dotted name ("db.table") is split on "." and each part is quoted
        separately ("`db`.`table`"). Embedded backticks are doubled, as MySQL
        requires inside backtick-quoted identifiers.
        """
        return ".".join(f"`{part.replace('`', '``')}`" for part in table.split("."))

    @staticmethod
    def _quote_field_name(field):
        """
        Quote field name

        The whole name is wrapped in backticks; embedded backticks are doubled,
        as MySQL requires inside backtick-quoted identifiers.
        """
        return f"`{str(field).replace('`', '``')}`"