Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions tests/core/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.

import pickle
import sqlite3

import pytest

from burr.core import State
Expand Down Expand Up @@ -131,6 +134,97 @@ def test_sqlite_persister_list_app_ids_without_initialize_raises_runtime_error()
persister.cleanup()


def test_sqlite_persister_from_values(tmp_path):
persister = SQLLitePersister.from_values(
db_path=str(tmp_path / "test.db"), table_name="test_table"
)
try:
persister.initialize()
persister.save("pk", "app_id", 1, "position", State({"key": "value"}), "completed")
loaded = persister.load("pk", "app_id")
assert loaded["state"] == State({"key": "value"})
finally:
persister.cleanup()


def test_sqlite_persister_from_config(tmp_path):
persister = SQLLitePersister.from_config(
{"db_path": str(tmp_path / "test.db"), "table_name": "test_table"}
)
try:
persister.initialize()
assert persister.is_initialized()
assert persister.table_name == "test_table"
finally:
persister.cleanup()


def test_sqlite_persister_context_manager_closes_connection(tmp_path):
with SQLLitePersister(db_path=str(tmp_path / "test.db"), table_name="test_table") as persister:
persister.initialize()
with pytest.raises(sqlite3.ProgrammingError):
persister.connection.cursor()


def test_sqlite_persister_copy_creates_independent_connection(tmp_path):
persister = SQLLitePersister(db_path=str(tmp_path / "test.db"), table_name="test_table")
persister.initialize()
persister.save("pk", "app_id", 1, "position", State({"key": "value"}), "completed")
copied = persister.copy()
try:
assert copied.connection is not persister.connection
assert copied.db_path == persister.db_path
assert copied.table_name == persister.table_name
loaded = copied.load("pk", "app_id")
assert loaded["state"] == State({"key": "value"})
finally:
persister.cleanup()
copied.cleanup()


def test_sqlite_persister_pickle_roundtrip_reconnects(tmp_path):
persister = SQLLitePersister(db_path=str(tmp_path / "test.db"), table_name="test_table")
persister.initialize()
persister.save("pk", "app_id", 1, "position", State({"key": "value"}), "completed")
unpickled = pickle.loads(pickle.dumps(persister))
try:
loaded = unpickled.load("pk", "app_id")
assert loaded["state"] == State({"key": "value"})
finally:
persister.cleanup()
unpickled.cleanup()


def test_sqlite_persister_load_latest_when_app_id_is_none():
persister = SQLLitePersister(db_path=":memory:", table_name="test_table")
try:
persister.initialize()
persister.save("pk", "app_id1", 1, "position", State({"key": "value1"}), "completed")
persister.save("pk", "app_id2", 1, "position", State({"key": "value2"}), "completed")
loaded = persister.load("pk", None)
assert loaded is not None
assert loaded["app_id"] in ("app_id1", "app_id2")
expected_value = "value1" if loaded["app_id"] == "app_id1" else "value2"
assert loaded["state"] == State({"key": expected_value})
finally:
persister.cleanup()


def test_sqlite_persister_load_specific_sequence_id():
persister = SQLLitePersister(db_path=":memory:", table_name="test_table")
try:
persister.initialize()
persister.save("pk", "app_id", 1, "first", State({"key": "value1"}), "completed")
persister.save("pk", "app_id", 2, "second", State({"key": "value2"}), "completed")
loaded = persister.load("pk", "app_id", sequence_id=1)
assert loaded["sequence_id"] == 1
assert loaded["position"] == "first"
assert loaded["state"] == State({"key": "value1"})
assert persister.load("pk", "app_id", sequence_id=99) is None
finally:
persister.cleanup()


@pytest.mark.parametrize(
"method_name,kwargs",
[
Expand Down
Loading