diff --git a/mylib/db.py b/mylib/db.py index d0dc3c7..efc8001 100644 --- a/mylib/db.py +++ b/mylib/db.py @@ -192,6 +192,17 @@ class DB: params[param] = value return params + @staticmethod + def _get_unique_param_name(field, params): + """Return a unique parameter name based on specified field name""" + param = field + if field in params: + idx = 1 + while param in params: + param = f"{field}_{idx}" + idx += 1 + return param + @classmethod def _format_where_clauses(cls, where_clauses, params=None, where_op=None): """ @@ -237,16 +248,22 @@ class DB: if isinstance(where_clauses, dict): sql_where_clauses = [] for field, value in where_clauses.items(): - param = field - if field in params: - idx = 1 - while param in params: - param = f"{field}_{idx}" - idx += 1 - cls._combine_params(params, {param: value}) - sql_where_clauses.append( - f"{cls._quote_field_name(field)} = {cls.format_param(param)}" - ) + if isinstance(value, list): + param_names = [] + for idx, v in enumerate(value): + param = cls._get_unique_param_name(f"{field}_{idx}", params) + cls._combine_params(params, **{param: v}) + param_names.append(param) + sql_where_clauses.append( + f"{cls._quote_field_name(field)} IN " + f"({', '.join([cls.format_param(param) for param in param_names])})" + ) + else: + param = cls._get_unique_param_name(field, params) + cls._combine_params(params, **{param: value}) + sql_where_clauses.append( + f"{cls._quote_field_name(field)} = {cls.format_param(param)}" + ) return (f" {where_op} ".join(sql_where_clauses), params) raise DBUnsupportedWHEREClauses(where_clauses)