Db: allow to specify a list as a WHERE clause value and they will be combined as a IN clause
Some checks failed
Run tests / tests (push) Failing after 1m22s

This commit is contained in:
Benjamin Renard 2024-10-14 18:49:57 +02:00
parent 601857a2d4
commit 6ccea48b23
Signed by: bn8
GPG key ID: 3E2E1CE1907115BC
4 changed files with 48 additions and 10 deletions

View file

@ -192,6 +192,17 @@ class DB:
params[param] = value params[param] = value
return params return params
@staticmethod
def _get_unique_param_name(field, params):
"""Return a unique parameter name based on specified field name"""
param = field
if field in params:
idx = 1
while param in params:
param = f"{field}_{idx}"
idx += 1
return param
@classmethod @classmethod
def _format_where_clauses(cls, where_clauses, params=None, where_op=None): def _format_where_clauses(cls, where_clauses, params=None, where_op=None):
""" """
@ -237,16 +248,22 @@ class DB:
if isinstance(where_clauses, dict): if isinstance(where_clauses, dict):
sql_where_clauses = [] sql_where_clauses = []
for field, value in where_clauses.items(): for field, value in where_clauses.items():
param = field if isinstance(value, list):
if field in params: param_names = []
idx = 1 for idx, v in enumerate(value):
while param in params: param = cls._get_unique_param_name(f"{field}_{idx}", params)
param = f"{field}_{idx}" cls._combine_params(params, **{param: v})
idx += 1 param_names.append(param)
cls._combine_params(params, {param: value}) sql_where_clauses.append(
sql_where_clauses.append( f"{cls._quote_field_name(field)} IN "
f"{cls._quote_field_name(field)} = {cls.format_param(param)}" f"({', '.join([cls.format_param(param) for param in param_names])})"
) )
else:
param = cls._get_unique_param_name(field, params)
cls._combine_params(params, **{param: value})
sql_where_clauses.append(
f"{cls._quote_field_name(field)} = {cls.format_param(param)}"
)
return (f" {where_op} ".join(sql_where_clauses), params) return (f" {where_op} ".join(sql_where_clauses), params)
raise DBUnsupportedWHEREClauses(where_clauses) raise DBUnsupportedWHEREClauses(where_clauses)

View file

@ -234,6 +234,13 @@ def test_format_where_clauses_tuple_clause_with_params():
assert MyDB._format_where_clauses(where_clauses) == where_clauses assert MyDB._format_where_clauses(where_clauses) == where_clauses
def test_format_where_clauses_with_list_as_value():
assert MyDB._format_where_clauses({"test": [1, 2]}) == (
"`test` IN (%(test_0)s, %(test_1)s)",
{"test_0": 1, "test_1": 2},
)
def test_format_where_clauses_dict(): def test_format_where_clauses_dict():
where_clauses = {"test1": 1, "test2": 2} where_clauses = {"test1": 1, "test2": 2}
assert MyDB._format_where_clauses(where_clauses) == ( assert MyDB._format_where_clauses(where_clauses) == (

View file

@ -226,6 +226,13 @@ def test_format_where_clauses_tuple_clause_with_params():
assert OracleDB._format_where_clauses(where_clauses) == where_clauses assert OracleDB._format_where_clauses(where_clauses) == where_clauses
def test_format_where_clauses_with_list_as_value():
assert OracleDB._format_where_clauses({"test": [1, 2]}) == (
'"test" IN (:test_0, :test_1)',
{"test_0": 1, "test_1": 2},
)
def test_format_where_clauses_dict(): def test_format_where_clauses_dict():
where_clauses = {"test1": 1, "test2": 2} where_clauses = {"test1": 1, "test2": 2}
assert OracleDB._format_where_clauses(where_clauses) == ( assert OracleDB._format_where_clauses(where_clauses) == (

View file

@ -226,6 +226,13 @@ def test_format_where_clauses_tuple_clause_with_params():
assert PgDB._format_where_clauses(where_clauses) == where_clauses assert PgDB._format_where_clauses(where_clauses) == where_clauses
def test_format_where_clauses_with_list_as_value():
assert PgDB._format_where_clauses({"test": [1, 2]}) == (
'"test" IN (%(test_0)s, %(test_1)s)',
{"test_0": 1, "test_1": 2},
)
def test_format_where_clauses_dict(): def test_format_where_clauses_dict():
where_clauses = {"test1": 1, "test2": 2} where_clauses = {"test1": 1, "test2": 2}
assert PgDB._format_where_clauses(where_clauses) == ( assert PgDB._format_where_clauses(where_clauses) == (