diff --git a/mylib/db.py b/mylib/db.py index d0dc3c7..efc8001 100644 --- a/mylib/db.py +++ b/mylib/db.py @@ -192,6 +192,17 @@ class DB: params[param] = value return params + @staticmethod + def _get_unique_param_name(field, params): + """Return a unique parameter name based on specified field name""" + param = field + if field in params: + idx = 1 + while param in params: + param = f"{field}_{idx}" + idx += 1 + return param + @classmethod def _format_where_clauses(cls, where_clauses, params=None, where_op=None): """ @@ -237,16 +248,22 @@ class DB: if isinstance(where_clauses, dict): sql_where_clauses = [] for field, value in where_clauses.items(): - param = field - if field in params: - idx = 1 - while param in params: - param = f"{field}_{idx}" - idx += 1 - cls._combine_params(params, {param: value}) - sql_where_clauses.append( - f"{cls._quote_field_name(field)} = {cls.format_param(param)}" - ) + if isinstance(value, list): + param_names = [] + for idx, v in enumerate(value): + param = cls._get_unique_param_name(f"{field}_{idx}", params) + cls._combine_params(params, **{param: v}) + param_names.append(param) + sql_where_clauses.append( + f"{cls._quote_field_name(field)} IN " + f"({', '.join([cls.format_param(param) for param in param_names])})" + ) + else: + param = cls._get_unique_param_name(field, params) + cls._combine_params(params, **{param: value}) + sql_where_clauses.append( + f"{cls._quote_field_name(field)} = {cls.format_param(param)}" + ) return (f" {where_op} ".join(sql_where_clauses), params) raise DBUnsupportedWHEREClauses(where_clauses) diff --git a/tests/test_mysql.py b/tests/test_mysql.py index 2cd17db..14c2dc4 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -234,6 +234,13 @@ def test_format_where_clauses_tuple_clause_with_params(): assert MyDB._format_where_clauses(where_clauses) == where_clauses +def test_format_where_clauses_with_list_as_value(): + assert MyDB._format_where_clauses({"test": [1, 2]}) == ( + "`test` IN (%(test_0)s, %(test_1)s)", + {"test_0": 1, "test_1": 2}, + ) + + def test_format_where_clauses_dict(): where_clauses = {"test1": 1, "test2": 2} assert MyDB._format_where_clauses(where_clauses) == ( diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 99081ec..fdf8bc6 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -226,6 +226,13 @@ def test_format_where_clauses_tuple_clause_with_params(): assert OracleDB._format_where_clauses(where_clauses) == where_clauses +def test_format_where_clauses_with_list_as_value(): + assert OracleDB._format_where_clauses({"test": [1, 2]}) == ( + '"test" IN (:test_0, :test_1)', + {"test_0": 1, "test_1": 2}, + ) + + def test_format_where_clauses_dict(): where_clauses = {"test1": 1, "test2": 2} assert OracleDB._format_where_clauses(where_clauses) == ( diff --git a/tests/test_pgsql.py b/tests/test_pgsql.py index d31cd4f..dabdc80 100644 --- a/tests/test_pgsql.py +++ b/tests/test_pgsql.py @@ -226,6 +226,13 @@ def test_format_where_clauses_tuple_clause_with_params(): assert PgDB._format_where_clauses(where_clauses) == where_clauses +def test_format_where_clauses_with_list_as_value(): + assert PgDB._format_where_clauses({"test": [1, 2]}) == ( + '"test" IN (%(test_0)s, %(test_1)s)', + {"test_0": 1, "test_1": 2}, + ) + + def test_format_where_clauses_dict(): where_clauses = {"test1": 1, "test2": 2} assert PgDB._format_where_clauses(where_clauses) == (