diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py index 18c65a0d1..e6bac3378 100644 --- a/tests/core/test_persistence.py +++ b/tests/core/test_persistence.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. +import pickle +import sqlite3 + import pytest from burr.core import State @@ -131,6 +134,97 @@ def test_sqlite_persister_list_app_ids_without_initialize_raises_runtime_error() persister.cleanup() +def test_sqlite_persister_from_values(tmp_path): + persister = SQLLitePersister.from_values( + db_path=str(tmp_path / "test.db"), table_name="test_table" + ) + try: + persister.initialize() + persister.save("pk", "app_id", 1, "position", State({"key": "value"}), "completed") + loaded = persister.load("pk", "app_id") + assert loaded["state"] == State({"key": "value"}) + finally: + persister.cleanup() + + +def test_sqlite_persister_from_config(tmp_path): + persister = SQLLitePersister.from_config( + {"db_path": str(tmp_path / "test.db"), "table_name": "test_table"} + ) + try: + persister.initialize() + assert persister.is_initialized() + assert persister.table_name == "test_table" + finally: + persister.cleanup() + + +def test_sqlite_persister_context_manager_closes_connection(tmp_path): + with SQLLitePersister(db_path=str(tmp_path / "test.db"), table_name="test_table") as persister: + persister.initialize() + with pytest.raises(sqlite3.ProgrammingError): + persister.connection.cursor() + + +def test_sqlite_persister_copy_creates_independent_connection(tmp_path): + persister = SQLLitePersister(db_path=str(tmp_path / "test.db"), table_name="test_table") + persister.initialize() + persister.save("pk", "app_id", 1, "position", State({"key": "value"}), "completed") + copied = persister.copy() + try: + assert copied.connection is not persister.connection + assert copied.db_path == persister.db_path + assert copied.table_name == persister.table_name + loaded = copied.load("pk", "app_id") + assert loaded["state"] == State({"key": "value"}) + finally: + persister.cleanup() + copied.cleanup() + + +def test_sqlite_persister_pickle_roundtrip_reconnects(tmp_path): + persister = SQLLitePersister(db_path=str(tmp_path / "test.db"), table_name="test_table") + persister.initialize() + persister.save("pk", "app_id", 1, "position", State({"key": "value"}), "completed") + unpickled = pickle.loads(pickle.dumps(persister)) + try: + loaded = unpickled.load("pk", "app_id") + assert loaded["state"] == State({"key": "value"}) + finally: + persister.cleanup() + unpickled.cleanup() + + +def test_sqlite_persister_load_latest_when_app_id_is_none(): + persister = SQLLitePersister(db_path=":memory:", table_name="test_table") + try: + persister.initialize() + persister.save("pk", "app_id1", 1, "position", State({"key": "value1"}), "completed") + persister.save("pk", "app_id2", 1, "position", State({"key": "value2"}), "completed") + loaded = persister.load("pk", None) + assert loaded is not None + assert loaded["app_id"] in ("app_id1", "app_id2") + expected_value = "value1" if loaded["app_id"] == "app_id1" else "value2" + assert loaded["state"] == State({"key": expected_value}) + finally: + persister.cleanup() + + +def test_sqlite_persister_load_specific_sequence_id(): + persister = SQLLitePersister(db_path=":memory:", table_name="test_table") + try: + persister.initialize() + persister.save("pk", "app_id", 1, "first", State({"key": "value1"}), "completed") + persister.save("pk", "app_id", 2, "second", State({"key": "value2"}), "completed") + loaded = persister.load("pk", "app_id", sequence_id=1) + assert loaded["sequence_id"] == 1 + assert loaded["position"] == "first" + assert loaded["state"] == State({"key": "value1"}) + assert persister.load("pk", "app_id", sequence_id=99) is None + finally: + persister.cleanup() + + @pytest.mark.parametrize( "method_name,kwargs", [