113 lines
3.2 KiB
Python
113 lines
3.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
""" MySQL client """
|
|
|
|
import logging
|
|
import sys
|
|
|
|
import MySQLdb
|
|
from MySQLdb._exceptions import Error
|
|
|
|
from mylib.db import DB
|
|
from mylib.db import DBFailToConnect
|
|
|
|
|
|
# Module-level logger, named after this module (used by MyDB below)
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MyDB(DB):
    """ MySQL client """

    # MySQL connection parameters, set per instance by __init__()
    _host = None
    _user = None
    _pwd = None
    _db = None

    def __init__(self, host, user, pwd, db, charset=None, **kwargs):
        """
        Initialize the MySQL client with its connection parameters

        :param host: The MySQL server hostname
        :param user: The MySQL user name
        :param pwd: The MySQL user password
        :param db: The MySQL database name
        :param charset: The connection charset (optional, 'utf8' is used
                        when not provided or empty)
        :param kwargs: Extra keyword arguments, passed as-is to the parent
                       DB class constructor
        """
        self._host = host
        self._user = user
        self._pwd = pwd
        self._db = db
        self._charset = charset if charset else 'utf8'
        super().__init__(**kwargs)

    def connect(self, exit_on_error=True):
        """
        Connect to MySQL server

        Nothing is done if a connection is already set (self._conn is not
        None).

        :param exit_on_error: If True (default), exit the program with
                              status 1 on connection error, otherwise raise
                              a DBFailToConnect exception

        :return: True once connected
        :rtype: bool

        :raises DBFailToConnect: On connection error, if exit_on_error is
                                 False
        """
        if self._conn is None:
            try:
                self._conn = MySQLdb.connect(
                    host=self._host, user=self._user, passwd=self._pwd,
                    db=self._db, charset=self._charset, use_unicode=True)
            except Error as err:
                log.fatal(
                    'An error occured during MySQL database connection (%s@%s:%s).',
                    self._user, self._host, self._db, exc_info=1
                )
                if exit_on_error:
                    sys.exit(1)
                raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err
        return True

    def doSQL(self, sql, params=None):
        """
        Run SQL query and commit changes (rollback on error)

        In just-try mode (self.just_try), the query is only logged and not
        executed.

        :param sql: The SQL query
        :param params: The SQL query's parameters as dict (optional)

        :return: True on success, False otherwise
        :rtype: bool
        """
        if self.just_try:
            log.debug("Just-try mode : do not really execute SQL query '%s'", sql)
            return True

        cursor = self._conn.cursor()
        try:
            self._log_query(sql, params)
            cursor.execute(sql, params)
            self._conn.commit()
            return True
        except Error:
            self._log_query_exception(sql, params)
            self._conn.rollback()
            return False
        finally:
            # Always release the cursor, whatever the query outcome
            cursor.close()

    def doSelect(self, sql, params=None):
        """
        Run SELECT SQL query and return list of selected rows as dict

        :param sql: The SQL query
        :param params: The SQL query's parameters as dict (optional)

        :return: List of selected rows as dict (column name -> value) on
                 success, False otherwise
        :rtype: list, bool
        """
        cursor = None
        try:
            self._log_query(sql, params)
            cursor = self._conn.cursor()
            cursor.execute(sql, params)
            # Column names are the same for every row: extract them once.
            # description is None for statements without result set.
            columns = [column[0] for column in cursor.description or ()]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]
        except Error:
            self._log_query_exception(sql, params)
            return False
        finally:
            # cursor stays None if opening it failed
            if cursor is not None:
                cursor.close()

    @staticmethod
    def _quote_identifier(name):
        """
        Quote a single MySQL identifier with backticks

        Backticks inside the identifier are doubled, as required by MySQL
        identifier quoting rules, so the name can not break out of quoting.
        """
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _quote_table_name(table):
        """
        Quote table name

        A dotted name is split on dots and each part is quoted separately
        (e.g. db.table gives `db`.`table`).
        """
        return '.'.join(
            MyDB._quote_identifier(part)
            for part in table.split('.')
        )

    @staticmethod
    def _quote_field_name(field):
        """ Quote field name """
        return MyDB._quote_identifier(field)
|