diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py index 2b907e35131d06..495ef97fa3c61c 100644 --- a/Lib/test/test_sqlite3/test_hooks.py +++ b/Lib/test/test_sqlite3/test_hooks.py @@ -24,11 +24,15 @@ import sqlite3 as sqlite import unittest +from test.support import import_helper from test.support.os_helper import TESTFN, unlink from .util import memory_database, cx_limit, with_tracebacks from .util import MemoryDatabaseMixin +# TODO(picnixz): increase test coverage for other callbacks +# such as 'func', 'step', 'finalize', and 'collation'. + class CollationTests(MemoryDatabaseMixin, unittest.TestCase): @@ -129,8 +133,55 @@ def test_deregister_collation(self): self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') +class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase): + + def assert_not_authorized(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"): + func(*args, **kwargs) + + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied. + + def test_authorizer_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_authorizer(lambda: None) + self.assert_not_authorized(self.cx.execute, "select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_authorizer_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + def handler(*a, **kw): + self.cx.set_authorizer(None) + raise ZeroDivisionError("hello world") + + self.cx.set_authorizer(handler) + self.assert_not_authorized(self.cx.execute, "select * from test") + + @with_tracebacks(OverflowError) + def test_authorizer_concurrent_mutation_with_overflown_value(self): + _testcapi = import_helper.import_module("_testcapi") + self.cx.execute("create table if not exists test(a number)") + + def handler(*a, **kw): + self.cx.set_authorizer(None) + # We expect 'int' at the C level, so this one will raise + # when converting via PyLong_Int(). + return _testcapi.INT_MAX + 1 + + self.cx.set_authorizer(handler) + self.assert_not_authorized(self.cx.execute, "select * from test") + + class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): + def assert_interrupted(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"): + func(*args, **kwargs) + def test_progress_handler_used(self): """ Test that the progress handler is invoked once it is set. @@ -219,11 +270,48 @@ def bad_progress(): create table foo(a, b) """) - def test_progress_handler_keyword_args(self): + def test_set_progress_handler_keyword_args(self): with self.assertRaisesRegex(TypeError, 'takes at least 1 positional argument'): self.con.set_progress_handler(progress_handler=lambda: None, n=1) + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied. + + def test_progress_handler_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_progress_handler(lambda x: None, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_progress_handler_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + def handler(*a, **kw): + self.cx.set_progress_handler(None, 1) + raise ZeroDivisionError("hello world") + + self.cx.set_progress_handler(handler, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + def test_progress_handler_concurrent_mutation_in_conversion(self): + self.cx.execute("create table if not exists test(a number)") + + class Handler: + def __bool__(_): + # clear the progress handler + self.cx.set_progress_handler(None, 1) + raise ValueError # force PyObject_True() to fail + + self.cx.set_progress_handler(Handler.__init__, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + # Running with tracebacks makes the second execution of this + # function raise another exception because of a database change. + class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): @@ -345,11 +433,40 @@ def test_trace_bad_handler(self): cx.set_trace_callback(lambda stmt: 5/0) cx.execute("select 1") - def test_trace_keyword_args(self): + def test_set_trace_callback_keyword_args(self): with self.assertRaisesRegex(TypeError, 'takes exactly 1 positional argument'): self.con.set_trace_callback(trace_callback=lambda: None) + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied, + # but for the trace handler, exceptions are never re-raised (only + # printed when needed). + + @with_tracebacks( + TypeError, + regex=r".*\(\) missing 6 required positional arguments", + ) + def test_trace_handler_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None) + self.cx.execute("select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_trace_callback_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + def handler(statement): + # clear the progress handler + self.cx.set_trace_callback(None) + raise ZeroDivisionError("hello world") + + self.cx.set_trace_callback(handler) + self.cx.execute("select * from test") + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst new file mode 100644 index 00000000000000..246979e91d76b5 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst @@ -0,0 +1,2 @@ +:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks +are mutated during a callback execution. Patch by Bénédikt Tran. diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 83ff8e60557c07..cde06c965ad4e3 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -145,7 +145,8 @@ class _sqlite3.Connection "pysqlite_Connection *" "clinic_state()->ConnectionTyp /*[clinic end generated code: output=da39a3ee5e6b4b0d input=67369db2faf80891]*/ static int _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self); -static void free_callback_context(callback_context *ctx); +static void incref_callback_context(callback_context *ctx); +static void decref_callback_context(callback_context *ctx); static void set_callback_context(callback_context **ctx_pp, callback_context *ctx); static int connection_close(pysqlite_Connection *self); @@ -936,8 +937,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) args = _pysqlite_build_py_params(context, argc, argv); if (args) { callback_context *ctx = (callback_context *)sqlite3_user_data(context); - assert(ctx != NULL); + incref_callback_context(ctx); py_retval = PyObject_CallObject(ctx->callable, args); + decref_callback_context(ctx); Py_DECREF(args); } @@ -964,7 +966,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) PyObject* stepmethod = NULL; callback_context *ctx = (callback_context *)sqlite3_user_data(context); - assert(ctx != NULL); + incref_callback_context(ctx); aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); if (aggregate_instance == NULL) { @@ -1002,6 +1004,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) } error: + decref_callback_context(ctx); Py_XDECREF(stepmethod); Py_XDECREF(function_result); @@ -1033,9 +1036,10 @@ final_callback(sqlite3_context *context) PyObject *exc = PyErr_GetRaisedException(); callback_context *ctx = (callback_context *)sqlite3_user_data(context); - assert(ctx != NULL); + incref_callback_context(ctx); function_result = PyObject_CallMethodNoArgs(*aggregate_instance, ctx->state->str_finalize); + decref_callback_context(ctx); Py_DECREF(*aggregate_instance); ok = 0; @@ -1107,6 +1111,7 @@ create_callback_context(PyTypeObject *cls, PyObject *callable) callback_context *ctx = PyMem_Malloc(sizeof(callback_context)); if (ctx != NULL) { PyObject *module = PyType_GetModule(cls); + ctx->refcount = 1; ctx->callable = Py_NewRef(callable); ctx->module = Py_NewRef(module); ctx->state = pysqlite_get_state(module); @@ -1118,11 +1123,33 @@ static void free_callback_context(callback_context *ctx) { assert(ctx != NULL); + assert(ctx->refcount == 0); Py_XDECREF(ctx->callable); Py_XDECREF(ctx->module); PyMem_Free(ctx); } +static inline void +incref_callback_context(callback_context *ctx) +{ + assert(PyGILState_Check()); + assert(ctx != NULL); + assert(ctx->refcount > 0); + ctx->refcount++; +} + +static inline void +decref_callback_context(callback_context *ctx) +{ + assert(PyGILState_Check()); + assert(ctx != NULL); + assert(ctx->refcount > 0); + ctx->refcount--; + if (ctx->refcount == 0) { + free_callback_context(ctx); + } +} + static void set_callback_context(callback_context **ctx_pp, callback_context *ctx) { @@ -1130,7 +1157,7 @@ set_callback_context(callback_context **ctx_pp, callback_context *ctx) callback_context *tmp = *ctx_pp; *ctx_pp = ctx; if (tmp != NULL) { - free_callback_context(tmp); + decref_callback_context(tmp); } } @@ -1141,7 +1168,7 @@ destructor_callback(void *ctx) // This function may be called without the GIL held, so we need to // ensure that we destroy 'ctx' with the GIL held. PyGILState_STATE gstate = PyGILState_Ensure(); - free_callback_context((callback_context *)ctx); + decref_callback_context((callback_context *)ctx); PyGILState_Release(gstate); } } @@ -1202,7 +1229,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self, func_callback, NULL, NULL, - &destructor_callback); // will decref func + &destructor_callback); // will free 'ctx' if (rc != SQLITE_OK) { /* Workaround for SQLite bug: no error code or string is available here */ @@ -1226,7 +1253,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) PyGILState_STATE gilstate = PyGILState_Ensure(); callback_context *ctx = (callback_context *)sqlite3_user_data(context); - assert(ctx != NULL); + incref_callback_context(ctx); int size = sizeof(PyObject *); PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); @@ -1258,6 +1285,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) Py_DECREF(res); exit: + decref_callback_context(ctx); Py_XDECREF(method); PyGILState_Release(gilstate); } @@ -1274,7 +1302,7 @@ value_callback(sqlite3_context *context) PyGILState_STATE gilstate = PyGILState_Ensure(); callback_context *ctx = (callback_context *)sqlite3_user_data(context); - assert(ctx != NULL); + incref_callback_context(ctx); int size = sizeof(PyObject *); PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); @@ -1282,6 +1310,8 @@ value_callback(sqlite3_context *context) assert(*cls != NULL); PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value); + decref_callback_context(ctx); + if (res == NULL) { int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError); set_sqlite_error(context, attr_err @@ -1403,7 +1433,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, 0, &step_callback, &final_callback, - &destructor_callback); // will decref func + &destructor_callback); // will free 'ctx' if (rc != SQLITE_OK) { /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(self->OperationalError, "Error creating aggregate"); @@ -1413,7 +1443,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, } static int -authorizer_callback(void *ctx, int action, const char *arg1, +authorizer_callback(void *ctx_vp, int action, const char *arg1, const char *arg2 , const char *dbname, const char *access_attempt_source) { @@ -1422,8 +1452,9 @@ authorizer_callback(void *ctx, int action, const char *arg1, PyObject *ret; int rc = SQLITE_DENY; - assert(ctx != NULL); - PyObject *callable = ((callback_context *)ctx)->callable; + callback_context *ctx = (callback_context *)ctx_vp; + incref_callback_context(ctx); + PyObject *callable = ctx->callable; ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname, access_attempt_source); @@ -1445,21 +1476,23 @@ authorizer_callback(void *ctx, int action, const char *arg1, Py_DECREF(ret); } + decref_callback_context(ctx); PyGILState_Release(gilstate); return rc; } static int -progress_callback(void *ctx) +progress_callback(void *ctx_vp) { PyGILState_STATE gilstate = PyGILState_Ensure(); int rc; PyObject *ret; - assert(ctx != NULL); - PyObject *callable = ((callback_context *)ctx)->callable; - ret = PyObject_CallNoArgs(callable); + callback_context *ctx = (callback_context *)ctx_vp; + incref_callback_context(ctx); + + ret = PyObject_CallNoArgs(ctx->callable); if (!ret) { /* abort query if error occurred */ rc = -1; @@ -1472,6 +1505,7 @@ progress_callback(void *ctx) print_or_clear_traceback(ctx); } + decref_callback_context(ctx); PyGILState_Release(gilstate); return rc; } @@ -1483,7 +1517,7 @@ progress_callback(void *ctx) * to ensure future compatibility. */ static int -trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) +trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql) { if (type != SQLITE_TRACE_STMT) { return 0; @@ -1491,8 +1525,9 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) PyGILState_STATE gilstate = PyGILState_Ensure(); - assert(ctx != NULL); - pysqlite_state *state = ((callback_context *)ctx)->state; + callback_context *ctx = (callback_context *)ctx_vp; + incref_callback_context(ctx); + pysqlite_state *state = ctx->state; assert(state != NULL); PyObject *py_statement = NULL; @@ -1506,7 +1541,7 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) PyErr_SetString(state->DataError, "Expanded SQL string exceeds the maximum string length"); - print_or_clear_traceback((callback_context *)ctx); + print_or_clear_traceback(ctx); // Fall back to unexpanded sql py_statement = PyUnicode_FromString((const char *)sql); @@ -1516,16 +1551,16 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) sqlite3_free((void *)expanded_sql); } if (py_statement) { - PyObject *callable = ((callback_context *)ctx)->callable; - PyObject *ret = PyObject_CallOneArg(callable, py_statement); + PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement); Py_DECREF(py_statement); Py_XDECREF(ret); } if (PyErr_Occurred()) { - print_or_clear_traceback((callback_context *)ctx); + print_or_clear_traceback(ctx); } exit: + decref_callback_context(ctx); PyGILState_Release(gilstate); return 0; } @@ -1950,6 +1985,8 @@ collation_callback(void *context, int text1_length, const void *text1_data, PyObject* retval = NULL; long longval; int result = 0; + callback_context *ctx = (callback_context *)context; + incref_callback_context(ctx); /* This callback may be executed multiple times per sqlite3_step(). Bail if * the previous call failed */ @@ -1966,8 +2003,6 @@ collation_callback(void *context, int text1_length, const void *text1_data, goto finally; } - callback_context *ctx = (callback_context *)context; - assert(ctx != NULL); PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs. size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET; retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL); @@ -1989,6 +2024,7 @@ collation_callback(void *context, int text1_length, const void *text1_data, } finally: + decref_callback_context(ctx); Py_XDECREF(string1); Py_XDECREF(string2); Py_XDECREF(retval); diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 7a748ee3ea0c58..703396a0c8db53 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -36,6 +36,7 @@ typedef struct _callback_context PyObject *callable; PyObject *module; pysqlite_state *state; + Py_ssize_t refcount; } callback_context; enum autocommit_mode {