diff --git a/Lib/shelve.py b/Lib/shelve.py
index 1010be1e09d702..63e4f116b52c4a 100644
--- a/Lib/shelve.py
+++ b/Lib/shelve.py
@@ -106,6 +106,23 @@ def __init__(self, dict, protocol=None, writeback=False,
             self.serializer = serializer
             self.deserializer = deserializer
 
+    @staticmethod
+    def _validate_serialized_value(serialized_value, original_value):
+        """Raise ShelveError if *serialized_value* is not bytes or str.
+
+        *original_value* is the object handed to the serializer; it is
+        not consulted by the check.
+        """
+        if isinstance(serialized_value, (bytes, str)):
+            return
+        # Name None explicitly instead of reporting "NoneType".
+        if serialized_value is None:
+            invalid_type = "None"
+        else:
+            invalid_type = type(serialized_value).__name__
+        msg = f"Serializer must return bytes or str, not {invalid_type}"
+        raise ShelveError(msg)
+
     def __iter__(self):
         for k in self.dict.keys():
             yield k.decode(self.keyencoding)
@@ -135,6 +152,7 @@ def __setitem__(self, key, value):
         if self.writeback:
             self.cache[key] = value
         serialized_value = self.serializer(value, self._protocol)
+        self._validate_serialized_value(serialized_value, value)
         self.dict[key.encode(self.keyencoding)] = serialized_value
 
     def __delitem__(self, key):
diff --git a/Lib/test/test_dbm_gnu.py b/Lib/test/test_dbm_gnu.py
index 66268c42a300b5..c092357d92a635 100644
--- a/Lib/test/test_dbm_gnu.py
+++ b/Lib/test/test_dbm_gnu.py
@@ -217,6 +217,25 @@ def test_localized_error(self):
             create_empty_file(os.path.join(d, 'test'))
             self.assertRaises(gdbm.error, gdbm.open, filename, 'r')
 
+    def test_type_errors(self):
+        self.g = gdbm.open(filename, 'c')
+        with self.assertRaisesRegex(
+            TypeError, "^a bytes-like object is required, not 'int'$",
+        ):
+            self.g[123]
+        with self.assertRaisesRegex(
+            TypeError, "^gdbm key must be bytes or str, not 'int'$",
+        ):
+            123 in self.g
+        with self.assertRaisesRegex(
+            TypeError, "^gdbm key must be bytes or str, not 'NoneType'$",
+        ):
+            self.g[None] = 123
+        with self.assertRaisesRegex(
+            TypeError, "^gdbm value must be bytes or str, not 'int'$",
+        ):
+            self.g['foo'] = 123
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/Lib/test/test_dbm_ndbm.py b/Lib/test/test_dbm_ndbm.py
index e0f31c9a9a337d..01da376b599b3d 100644
--- a/Lib/test/test_dbm_ndbm.py
+++ b/Lib/test/test_dbm_ndbm.py
@@ -160,6 +160,25 @@ def test_clear(self):
                 self.assertNotIn(k, db)
             self.assertEqual(len(db), 0)
 
+    def test_type_errors(self):
+        with dbm.ndbm.open(self.filename, 'c') as db:
+            with self.assertRaisesRegex(
+                TypeError, "^a bytes-like object is required, not 'int'$",
+            ):
+                db[123]
+            with self.assertRaisesRegex(
+                TypeError, "^dbm key must be bytes or str, not 'int'$",
+            ):
+                123 in db
+            with self.assertRaisesRegex(
+                TypeError, "^dbm key must be bytes or str, not 'NoneType'$",
+            ):
+                db[None] = 123
+            with self.assertRaisesRegex(
+                TypeError, "^dbm value must be bytes or str, not 'int'$",
+            ):
+                db['foo'] = 123
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py
index 64609ab9dd9a62..931f1ed8e6147f 100644
--- a/Lib/test/test_shelve.py
+++ b/Lib/test/test_shelve.py
@@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self):
         def serializer(obj, protocol):
             if isinstance(obj, (bytes, bytearray, str)):
                 if protocol == 5:
+                    if isinstance(obj, bytearray):
+                        return bytes(obj)  # DBM backends expect bytes
                     return obj
                 return type(obj).__name__
             elif isinstance(obj, array.array):
@@ -223,11 +225,10 @@ def deserializer(data):
                     )
 
     def test_custom_incomplete_serializer_and_deserializer(self):
-        dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
         os.mkdir(self.dirname)
         self.addCleanup(os_helper.rmtree, self.dirname)
 
-        with self.assertRaises(dbm_sqlite3.error):
+        with self.assertRaises(shelve.ShelveError):
             def serializer(obj, protocol=None):
                 pass
 
@@ -430,6 +431,81 @@ def setUp(self):
         dbm._defaultmod = self.dbm_mod
 
 
+class TestShelveValidation(unittest.TestCase):
+    dirname = os_helper.TESTFN
+    fname = os.path.join(dirname, os_helper.TESTFN)
+
+    def setup_test_dir(self):
+        os_helper.rmtree(self.dirname)
+        os.mkdir(self.dirname)
+
+    def setUp(self):
+        self.addCleanup(setattr, dbm, "_defaultmod", dbm._defaultmod)
+        os.mkdir(self.dirname)
+        self.addCleanup(os_helper.rmtree, self.dirname)
+
+    def test_serializer_unsupported_return_type(self):
+        def int_serializer(obj, protocol=None):
+            return 3
+
+        def none_serializer(obj, protocol=None):
+            return None
+
+        def deserializer(data):
+            if isinstance(data, bytes):
+                return data.decode("utf-8")
+            else:
+                return data
+
+        for module in dbm_iterator():
+            self.setup_test_dir()
+            dbm._defaultmod = module
+            # Create the file with this backend and check it is detected.
+            with module.open(self.fname, "c"):
+                pass
+            self.assertEqual(module.__name__, dbm.whichdb(self.fname))
+
+            with shelve.open(self.fname, serializer=none_serializer,
+                             deserializer=deserializer) as s:
+                with self.assertRaises(shelve.ShelveError) as cm:
+                    s["key"] = "value"
+                self.assertEqual("Serializer must return bytes or str, not None",
+                                 f"{cm.exception}")
+
+            with shelve.open(self.fname, serializer=int_serializer,
+                             deserializer=deserializer,) as s:
+                with self.assertRaises(shelve.ShelveError) as cm:
+                    s["key"] = "value"
+                self.assertEqual("Serializer must return bytes or str, not int",
+                                 f"{cm.exception}")
+
+    def test_shelve_type_compatibility(self):
+        for module in dbm_iterator():
+            self.setup_test_dir()
+            dbm._defaultmod = module
+            with shelve.Shelf(module.open(self.fname, "c")) as shelf:
+                shelf["string"] = "hello"
+                shelf["bytes"] = b"world"
+                shelf["number"] = 42
+                shelf["list"] = [1, 2, 3]
+                shelf["dict"] = {"key": "value"}
+                shelf["set"] = {1, 2, 3}
+                shelf["tuple"] = (1, 2, 3)
+                shelf["complex"] = 1 + 2j
+                shelf["bytearray"] = bytearray(b"test")
+                shelf["array"] = array.array("i", [1, 2, 3])
+                self.assertEqual(shelf["string"], "hello")
+                self.assertEqual(shelf["bytes"], b"world")
+                self.assertEqual(shelf["number"], 42)
+                self.assertEqual(shelf["list"], [1, 2, 3])
+                self.assertEqual(shelf["dict"], {"key": "value"})
+                self.assertEqual(shelf["set"], {1, 2, 3})
+                self.assertEqual(shelf["tuple"], (1, 2, 3))
+                self.assertEqual(shelf["complex"], 1 + 2j)
+                self.assertEqual(shelf["bytearray"], bytearray(b"test"))
+                self.assertEqual(shelf["array"], array.array("i", [1, 2, 3]))
+
+
 from test import mapping_tests
 
 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
diff --git a/Misc/NEWS.d/next/Library/2025-11-16-15-29-49.gh-issue-137899.JnbEmT.rst b/Misc/NEWS.d/next/Library/2025-11-16-15-29-49.gh-issue-137899.JnbEmT.rst
new file mode 100644
index 00000000000000..12faf24a0a6d46
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2025-11-16-15-29-49.gh-issue-137899.JnbEmT.rst
@@ -0,0 +1 @@
+The :mod:`shelve` module now raises ``ShelveError`` with a descriptive message when a custom serializer returns something other than bytes or str, and :mod:`dbm.gnu` and :mod:`dbm.ndbm` error messages now distinguish key type errors from value type errors. Patch by Furkan Onder.
diff --git a/Modules/_dbmmodule.c b/Modules/_dbmmodule.c
index f88861fa24423b..2b0a73a22d3c6c 100644
--- a/Modules/_dbmmodule.c
+++ b/Modules/_dbmmodule.c
@@ -231,8 +231,9 @@ dbm_ass_sub_lock_held(PyObject *self, PyObject *v, PyObject *w)
     dbmobject *dp = dbmobject_CAST(self);
 
     if ( !PyArg_Parse(v, "s#", &krec.dptr, &tmp_size) ) {
-        PyErr_SetString(PyExc_TypeError,
-                        "dbm mappings have bytes or string keys only");
+        PyErr_Format(PyExc_TypeError,
+                     "dbm key must be bytes or str, not '%T'",
+                     v);
         return -1;
     }
     _dbm_state *state = PyType_GetModuleState(Py_TYPE(dp));
@@ -258,8 +259,9 @@ dbm_ass_sub_lock_held(PyObject *self, PyObject *v, PyObject *w)
         }
     } else {
         if ( !PyArg_Parse(w, "s#", &drec.dptr, &tmp_size) ) {
-            PyErr_SetString(PyExc_TypeError,
-                 "dbm mappings have bytes or string elements only");
+            PyErr_Format(PyExc_TypeError,
+                         "dbm value must be bytes or str, not '%T'",
+                         w);
             return -1;
         }
         drec.dsize = tmp_size;
@@ -369,8 +371,7 @@ dbm_contains_lock_held(PyObject *self, PyObject *arg)
     }
     else if (!PyBytes_Check(arg)) {
         PyErr_Format(PyExc_TypeError,
-                     "dbm key must be bytes or string, not %.100s",
-                     Py_TYPE(arg)->tp_name);
+                     "dbm key must be bytes or str, not '%T'", arg);
         return -1;
     }
     else {
diff --git a/Modules/_gdbmmodule.c b/Modules/_gdbmmodule.c
index 72f568ceb06987..b1bbcc86df5892 100644
--- a/Modules/_gdbmmodule.c
+++ b/Modules/_gdbmmodule.c
@@ -237,12 +237,14 @@ gdbm_bool(PyObject *op)
 // This function is needed to support PY_SSIZE_T_CLEAN.
 // Return 1 on success, same to PyArg_Parse().
 static int
-parse_datum(PyObject *o, datum *d, const char *failmsg)
+parse_datum(PyObject *o, datum *d, const char *item_name)
 {
     Py_ssize_t size;
     if (!PyArg_Parse(o, "s#", &d->dptr, &size)) {
-        if (failmsg != NULL) {
-            PyErr_SetString(PyExc_TypeError, failmsg);
+        if (item_name) {
+            PyErr_Format(PyExc_TypeError,
+                         "gdbm %s must be bytes or str, not '%T'",
+                         item_name, o);
         }
         return 0;
     }
@@ -318,11 +320,10 @@ static int
 gdbm_ass_sub_lock_held(PyObject *op, PyObject *v, PyObject *w)
 {
     datum krec, drec;
-    const char *failmsg = "gdbm mappings have bytes or string indices only";
     gdbmobject *dp = _gdbmobject_CAST(op);
     _gdbm_state *state = PyType_GetModuleState(Py_TYPE(dp));
 
-    if (!parse_datum(v, &krec, failmsg)) {
+    if (!parse_datum(v, &krec, "key")) {
         return -1;
     }
     if (dp->di_dbm == NULL) {
@@ -343,7 +344,7 @@ gdbm_ass_sub_lock_held(PyObject *op, PyObject *v, PyObject *w)
         }
     }
     else {
-        if (!parse_datum(w, &drec, failmsg)) {
+        if (!parse_datum(w, &drec, "value")) {
             return -1;
         }
         errno = 0;
@@ -491,8 +492,7 @@ gdbm_contains_lock_held(PyObject *self, PyObject *arg)
     }
     else if (!PyBytes_Check(arg)) {
         PyErr_Format(PyExc_TypeError,
-                     "gdbm key must be bytes or string, not %.100s",
-                     Py_TYPE(arg)->tp_name);
+                     "gdbm key must be bytes or str, not '%T'", arg);
         return -1;
     }
     else {