diff --git a/mylib/oracle.py b/mylib/oracle.py index 0f30411..2599253 100644 --- a/mylib/oracle.py +++ b/mylib/oracle.py @@ -125,7 +125,6 @@ class OracleDB: log.debug("Just-try mode : do not really execute SQL query '%s'", sql) return True - cursor = self._conn.cursor() try: log.debug( 'Run SQL query "%s" %s', @@ -135,10 +134,11 @@ class OracleDB: for key, value in params.items() ]) if params else "without params" ) - if isinstance(params, dict): - cursor.execute(sql, **params) - else: - cursor.execute(sql) + with self._conn.cursor() as cursor: + if isinstance(params, dict): + cursor.execute(sql, **params) + else: + cursor.execute(sql) self._conn.commit() return True except Exception: @@ -164,7 +164,6 @@ class OracleDB: :return: List of selected rows as dict on success, False otherwise :rtype: list, bool """ - cursor = self._conn.cursor() try: log.debug( 'Run SQL SELECT query "%s" %s', @@ -174,12 +173,15 @@ class OracleDB: for key, value in params.items() ]) if params else "without params" ) - if isinstance(params, dict): - cursor.execute(sql, **params) - else: - cursor.execute(sql) - cursor.rowfactory = lambda *args: dict(zip([d[0] for d in cursor.description], args)) - results = cursor.fetchall() + with self._conn.cursor() as cursor: + if isinstance(params, dict): + cursor.execute(sql, **params) + else: + cursor.execute(sql) + cursor.rowfactory = lambda *args: dict( + zip([d[0] for d in cursor.description], args) + ) + results = cursor.fetchall() return results except Exception: log.error( diff --git a/tests/test_oracle.py b/tests/test_oracle.py index e810dab..a1169e9 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -15,8 +15,10 @@ class FakeCXOracleCursor: self.expected_return = expected_return self.expected_just_try = expected_just_try self.expected_exception = expected_exception + self.opened = True def execute(self, sql, **params): + assert self.opened if self.expected_exception: raise Exception("%s.execute(%s, %s): expected exception" % (self, sql, params)) if self.expected_just_try and not sql.lower().startswith('select '): @@ -26,8 +28,16 @@ class FakeCXOracleCursor: return self.expected_return def fetchall(self): + assert self.opened return self.expected_return + def __enter__(self): + self.opened = True + return self + + def __exit__(self, *args): + self.opened = False + def __repr__(self): return "FakeCXOracleCursor(%s, %s, %s, %s)" % ( self.expected_sql, self.expected_params,