Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions Lib/test/test_sqlite3/test_userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,33 @@ def value(self): return 1 << 65
self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
self.cur.execute, self.query % "err_val_ret")

def test_close_conn_in_window_func_value(self):
# gh-145040: closing connection in window function value() callback.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
con.executemany("INSERT INTO t VALUES(?)",
[(i,) for i in range(20)])

class CloseConnWindow:
def step(self, value):
pass
def finalize(self):
return 0
def value(self):
con.close()
return 0
def inverse(self, value):
pass

con.create_window_function("evil_win", 1, CloseConnWindow)
with self.assertRaises(sqlite.OperationalError):
cursor = con.execute(
"SELECT evil_win(x) OVER "
"(ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM t"
)
list(cursor)
con.close()


class AggregateTests(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -723,6 +750,140 @@ def test_agg_keyword_args(self):
'takes exactly 3 positional arguments'):
self.con.create_aggregate("test", 1, aggregate_class=AggrText)

def test_aggr_close_conn_in_step(self):
# Connection.close() in an aggregate step callback must not crash.
con = sqlite.connect(":memory:", autocommit=True)
cur = con.cursor()
cur.execute("CREATE TABLE t(x INTEGER)")
for i in range(50):
cur.execute("INSERT INTO t VALUES (?)", (i,))

class CloseConnAgg:
def __init__(self):
self.total = 0

def step(self, value):
self.total += value
con.close()

def finalize(self):
return self.total

con.create_aggregate("agg_close", 1, CloseConnAgg)
with self.assertRaises(sqlite.OperationalError):
con.execute("SELECT agg_close(x) FROM t")
con.close()

def test_close_conn_in_nested_callback(self):
# gh-145040: close() must be prevented even in nested callbacks.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
for i in range(5):
con.execute("INSERT INTO t VALUES(?)", (i,))

def outer_func(x):
con.close()
return x

def inner_func(x):
return x * 10

con.create_function("outer_func", 1, outer_func)
con.create_function("inner_func", 1, inner_func)
with self.assertRaises(sqlite.OperationalError):
con.execute("SELECT outer_func(inner_func(x)) FROM t")
# Connection must still be usable after the failed close attempt.
self.assertEqual(con.execute("SELECT 1").fetchone(), (1,))
con.close()

def test_close_conn_in_nested_callback_caught(self):
# gh-145040: if the ProgrammingError from close() is caught inside
# the callback, execution continues normally.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
con.execute("INSERT INTO t VALUES(1)")

def swallow_close(x):
try:
con.close()
except sqlite.ProgrammingError:
pass
return x

con.create_function("swallow_close", 1, swallow_close)
# The close() was prevented and the exception was caught,
# so the execute should succeed.
result = con.execute("SELECT swallow_close(x) FROM t").fetchone()
self.assertEqual(result, (1,))
# Connection must still be usable.
self.assertEqual(con.execute("SELECT 1").fetchone(), (1,))
con.close()

def test_close_conn_in_udf_during_executemany(self):
# gh-145040: closing connection in UDF during executemany.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x)")

def close_conn(x):
con.close()
return x

con.create_function("close_conn", 1, close_conn)
with self.assertRaises(sqlite.OperationalError):
con.executemany("INSERT INTO t VALUES(close_conn(?))",
[(i,) for i in range(10)])
con.close()

def test_close_conn_in_progress_handler_during_iternext(self):
# gh-145040: closing connection in progress handler during iteration.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x)")
con.executemany("INSERT INTO t VALUES(?)",
[(i,) for i in range(100)])

count = 0
def close_progress():
nonlocal count
count += 1
if count >= 5:
con.close()
return 1
return 0

cursor = con.execute("SELECT * FROM t")
con.set_progress_handler(close_progress, 1)
with self.assertRaises(sqlite.OperationalError):
for row in cursor:
pass
con.close()

def test_close_conn_in_collation_callback(self):
# gh-145040: closing connection in collation callback.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x TEXT)")
con.executemany("INSERT INTO t VALUES(?)",
[(f"item_{i}",) for i in range(50)])

count = 0
def evil_collation(a, b):
nonlocal count
count += 1
if count == 10:
con.close()
if a < b:
return -1
elif a > b:
return 1
return 0

con.create_collation("evil_coll", evil_collation)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute(
"SELECT * FROM t ORDER BY x COLLATE evil_coll"
)
con.close()


class AuthorizerTests(unittest.TestCase):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fixed a crash in the :mod:`sqlite3` module caused by closing the database
connection from within a callback function invoked during
``sqlite3_step()`` (e.g., an aggregate ``step``, a user-defined function
via :meth:`~sqlite3.Connection.create_function`, a progress handler, or a
collation callback). Raise :exc:`~sqlite3.ProgrammingError` instead of
crashing.
34 changes: 33 additions & 1 deletion Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,17 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, PyObject *database,
goto error;
}

/* Create lists of weak references to blobs */
/* Create lists of weak references to cursors and blobs */
PyObject *cursors = PyList_New(0);
if (cursors == NULL) {
Py_DECREF(statement_cache);
goto error;
}

PyObject *blobs = PyList_New(0);
if (blobs == NULL) {
Py_DECREF(statement_cache);
Py_DECREF(cursors);
goto error;
}

Expand All @@ -299,6 +306,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, PyObject *database,
self->check_same_thread = check_same_thread;
self->thread_ident = PyThread_get_thread_ident();
self->statement_cache = statement_cache;
self->cursors = cursors;
self->blobs = blobs;
self->row_factory = Py_NewRef(Py_None);
self->text_factory = Py_NewRef(&PyUnicode_Type);
Expand Down Expand Up @@ -381,6 +389,7 @@ connection_traverse(PyObject *op, visitproc visit, void *arg)
pysqlite_Connection *self = _pysqlite_Connection_CAST(op);
Py_VISIT(Py_TYPE(self));
Py_VISIT(self->statement_cache);
Py_VISIT(self->cursors);
Py_VISIT(self->blobs);
Py_VISIT(self->row_factory);
Py_VISIT(self->text_factory);
Expand All @@ -405,6 +414,7 @@ connection_clear(PyObject *op)
{
pysqlite_Connection *self = _pysqlite_Connection_CAST(op);
Py_CLEAR(self->statement_cache);
Py_CLEAR(self->cursors);
Py_CLEAR(self->blobs);
Py_CLEAR(self->row_factory);
Py_CLEAR(self->text_factory);
Expand Down Expand Up @@ -655,6 +665,28 @@ pysqlite_connection_close_impl(pysqlite_Connection *self)
return NULL;
}

/* Check if any cursor is locked (actively executing a query);
* closing during a callback is illegal per the SQLite C API docs. */
assert(PyList_CheckExact(self->cursors));
Py_ssize_t n = PyList_GET_SIZE(self->cursors);
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *weakref = PyList_GET_ITEM(self->cursors, i);
PyObject *obj;
if (!PyWeakref_GetRef(weakref, &obj)) {
continue;
}
int locked = ((pysqlite_Cursor *)obj)->locked;
Py_DECREF(obj);
if (locked) {
PyTypeObject *tp = Py_TYPE(self);
pysqlite_state *state = pysqlite_get_state_by_type(tp);
PyErr_SetString(state->ProgrammingError,
"Cannot close the database connection "
"from within a callback function.");
return NULL;
}
}

pysqlite_close_all_blobs(self);
Py_CLEAR(self->statement_cache);
if (connection_close(self) < 0) {
Expand Down
3 changes: 2 additions & 1 deletion Modules/_sqlite/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ typedef struct

PyObject *statement_cache;

/* Lists of weak references to blobs used within this connection */
/* Lists of weak references to cursors and blobs used within this connection */
PyObject *cursors;
PyObject *blobs;

PyObject* row_factory;
Expand Down
29 changes: 28 additions & 1 deletion Modules/_sqlite/cursor.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@ class _sqlite3.Cursor "pysqlite_Cursor *" "clinic_state()->CursorType"
[clinic start generated code]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=3c5b8115c5cf30f1]*/

/*
* Registers a cursor with the connection.
*
* 0 => error; 1 => ok
*/
static int
register_cursor(pysqlite_Connection *connection, PyObject *cursor)
{
PyObject *weakref = PyWeakref_NewRef(cursor, NULL);
if (weakref == NULL) {
return 0;
}

if (PyList_Append(connection->cursors, weakref) < 0) {
Py_DECREF(weakref);
return 0;
}

Py_DECREF(weakref);
return 1;
}

/*[clinic input]
_sqlite3.Cursor.__init__ as pysqlite_cursor_init

Expand Down Expand Up @@ -138,6 +160,10 @@ pysqlite_cursor_init_impl(pysqlite_Cursor *self,
return -1;
}

if (!register_cursor(connection, (PyObject *)self)) {
return -1;
}

self->initialized = 1;

return 0;
Expand Down Expand Up @@ -1151,11 +1177,12 @@ pysqlite_cursor_iternext(PyObject *op)

self->locked = 1; // GH-80254: Prevent recursive use of cursors.
PyObject *row = _pysqlite_fetch_one_row(self);
self->locked = 0;
if (row == NULL) {
self->locked = 0;
return NULL;
}
int rc = stmt_step(stmt);
self->locked = 0;
if (rc == SQLITE_DONE) {
if (self->statement->is_dml) {
self->rowcount = (long)sqlite3_changes(self->connection->db);
Expand Down
Loading