Skip to content
13 changes: 5 additions & 8 deletions pyoverkiz/action_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,15 @@ def __init__(
executor: Callable[
[list[Action], ExecutionMode | None, str | None], Coroutine[None, None, str]
],
delay: float = 0.5,
max_actions: int = 20,
settings: ActionQueueSettings | None = None,
) -> None:
"""Initialize the action queue.

:param executor: Async function to execute batched actions
:param delay: Seconds to wait before auto-flushing (default 0.5)
:param max_actions: Maximum actions per batch before forced flush (default 20)
:param settings: Queue configuration (uses defaults if None)
"""
self._executor = executor
self._delay = delay
self._max_actions = max_actions
self._settings = settings or ActionQueueSettings()

self._pending_actions: list[Action] = []
self._pending_mode: ExecutionMode | None = None
Expand Down Expand Up @@ -188,7 +185,7 @@ async def add(
self._pending_waiters.append(waiter)

# If we hit max actions, flush immediately
if len(self._pending_actions) >= self._max_actions:
if len(self._pending_actions) >= self._settings.max_actions:
# Prepare the current batch for flushing (which includes the actions
# we just added). If we already flushed due to mode change, this is
# a second batch.
Expand All @@ -208,7 +205,7 @@ async def _delayed_flush(self) -> None:
"""Wait for the delay period, then flush the queue."""
waiters: list[QueuedExecution] = []
try:
await asyncio.sleep(self._delay)
await asyncio.sleep(self._settings.delay)
async with self._lock:
if not self._pending_actions:
return
Expand Down
36 changes: 12 additions & 24 deletions pyoverkiz/auth/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def build_auth_strategy(

if server == Server.SOMFY_EUROPE:
return SomfyAuthStrategy(
_ensure_username_password(credentials),
_ensure_credentials(credentials, UsernamePasswordCredentials),
session,
server_config,
ssl_context,
Expand All @@ -51,23 +51,23 @@ def build_auth_strategy(
Server.SAUTER_COZYTOUCH,
}:
return CozytouchAuthStrategy(
_ensure_username_password(credentials),
_ensure_credentials(credentials, UsernamePasswordCredentials),
session,
server_config,
ssl_context,
)

if server == Server.NEXITY:
return NexityAuthStrategy(
_ensure_username_password(credentials),
_ensure_credentials(credentials, UsernamePasswordCredentials),
session,
server_config,
ssl_context,
)

if server == Server.REXEL:
return RexelAuthStrategy(
_ensure_rexel(credentials),
_ensure_credentials(credentials, RexelOAuthCodeCredentials),
session,
server_config,
ssl_context,
Expand All @@ -79,7 +79,7 @@ def build_auth_strategy(
credentials, session, server_config, ssl_context
)
return BearerTokenAuthStrategy(
_ensure_token(credentials),
_ensure_credentials(credentials, TokenCredentials),
session,
server_config,
ssl_context,
Expand All @@ -91,29 +91,17 @@ def build_auth_strategy(
return BearerTokenAuthStrategy(credentials, session, server_config, ssl_context)

return SessionLoginStrategy(
_ensure_username_password(credentials),
_ensure_credentials(credentials, UsernamePasswordCredentials),
session,
server_config,
ssl_context,
)


def _ensure_username_password(credentials: Credentials) -> UsernamePasswordCredentials:
"""Validate that credentials are username/password based."""
if not isinstance(credentials, UsernamePasswordCredentials):
raise TypeError("UsernamePasswordCredentials are required for this server.")
return credentials


def _ensure_token(credentials: Credentials) -> TokenCredentials:
"""Validate that credentials carry a bearer token."""
if not isinstance(credentials, TokenCredentials):
raise TypeError("TokenCredentials are required for this server.")
return credentials


def _ensure_rexel(credentials: Credentials) -> RexelOAuthCodeCredentials:
"""Validate that credentials are of Rexel OAuth code type."""
if not isinstance(credentials, RexelOAuthCodeCredentials):
raise TypeError("RexelOAuthCodeCredentials are required for this server.")
def _ensure_credentials[C: Credentials](
credentials: Credentials, expected: type[C]
) -> C:
"""Validate that credentials match the expected type."""
if not isinstance(credentials, expected):
raise TypeError(f"{expected.__name__} are required for this server.")
return credentials
9 changes: 4 additions & 5 deletions pyoverkiz/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
)
from pyoverkiz.models import (
Action,
ActionGroup,
Device,
Event,
Execution,
Expand All @@ -47,6 +46,7 @@
HistoryExecution,
Option,
OptionParameter,
PersistedActionGroup,
Place,
ProtocolType,
ServerConfig,
Expand Down Expand Up @@ -222,8 +222,7 @@ def __init__(
queue_settings.validate()
self._action_queue = ActionQueue(
executor=self._execute_action_group_direct,
delay=queue_settings.delay,
max_actions=queue_settings.max_actions,
settings=queue_settings,
)

self._auth = build_auth_strategy(
Expand Down Expand Up @@ -577,10 +576,10 @@ async def cancel_execution(self, exec_id: str) -> None:
await self._delete(f"exec/current/setup/{exec_id}")

@retry_on_auth_error
async def get_action_groups(self) -> list[ActionGroup]:
async def get_action_groups(self) -> list[PersistedActionGroup]:
"""List action groups persisted on the server."""
response = await self._get("actionGroups")
return converter.structure(decamelize(response), list[ActionGroup])
return converter.structure(decamelize(response), list[PersistedActionGroup])

@retry_on_auth_error
async def get_places(self) -> Place:
Expand Down
27 changes: 15 additions & 12 deletions pyoverkiz/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,13 @@ class States:

_states: list[State]
_index: dict[str, State]
_pos: dict[str, int]

def __init__(self, states: list[State] | None = None) -> None:
"""Create a States container from a list of State objects or empty."""
self._states = list(states) if states else []
self._index = {state.name: state for state in self._states}
self._pos = {state.name: i for i, state in enumerate(self._states)}

def __iter__(self) -> Iterator[State]:
"""Return an iterator over contained State objects."""
Expand All @@ -159,10 +161,10 @@ def __setitem__(self, name: str, state: State) -> None:
"""Set or append a State identified by name."""
if state.name != name:
raise ValueError(f"State name {state.name!r} does not match key {name!r}")
if name in self._index:
idx = self._states.index(self._index[name])
self._states[idx] = state
if name in self._pos:
self._states[self._pos[name]] = state
else:
self._pos[name] = len(self._states)
self._states.append(state)
self._index[name] = state

Expand Down Expand Up @@ -509,24 +511,25 @@ class ActionGroup:
"""

actions: list[Action] = field(factory=list)
creation_time: int | None = None
last_update_time: int | None = None
label: str = field(repr=obfuscate_string, default="")
label: str | None = field(repr=obfuscate_string, default=None)
metadata: str | None = None
shortcut: bool | None = None
notification_type_mask: int | None = None
notification_condition: str | None = None
notification_text: str | None = None
notification_title: str | None = None
oid: str | None = field(repr=obfuscate_id, default=None)

def __attrs_post_init__(self) -> None:
"""Default label to empty string when None."""
if self.label is None:
self.label = ""

@define(kw_only=True)
class PersistedActionGroup(ActionGroup):
"""A server-persisted action group returned by the /actionGroups endpoint."""

oid: str = field(repr=obfuscate_id)
creation_time: int = 0
last_update_time: int = 0

@property
def id(self) -> str | None:
def id(self) -> str:
"""Alias for oid."""
return self.oid

Expand Down
34 changes: 20 additions & 14 deletions tests/test_action_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from pyoverkiz.action_queue import ActionQueue, QueuedExecution
from pyoverkiz.action_queue import ActionQueue, ActionQueueSettings, QueuedExecution
from pyoverkiz.enums import ExecutionMode, OverkizCommand
from pyoverkiz.models import Action, Command

Expand All @@ -24,7 +24,7 @@ async def executor(actions, mode, label):
@pytest.mark.asyncio
async def test_action_queue_single_action(mock_executor):
"""Test queue with a single action."""
queue = ActionQueue(executor=mock_executor, delay=0.1)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.1))

action = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -45,7 +45,7 @@ async def test_action_queue_single_action(mock_executor):
@pytest.mark.asyncio
async def test_action_queue_batching(mock_executor):
"""Test that multiple actions are batched together."""
queue = ActionQueue(executor=mock_executor, delay=0.2)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.2))

actions = [
Action(
Expand Down Expand Up @@ -75,7 +75,9 @@ async def test_action_queue_batching(mock_executor):
@pytest.mark.asyncio
async def test_action_queue_max_actions_flush(mock_executor):
"""Test that queue flushes when max actions is reached."""
queue = ActionQueue(executor=mock_executor, delay=10.0, max_actions=3)
queue = ActionQueue(
executor=mock_executor, settings=ActionQueueSettings(delay=10.0, max_actions=3)
)

actions = [
Action(
Expand Down Expand Up @@ -112,8 +114,8 @@ async def test_action_queue_max_actions_flush(mock_executor):

@pytest.mark.asyncio
async def test_action_queue_mode_change_flush(mock_executor):
"""Test that queue flushes when execution mode changes."""
queue = ActionQueue(executor=mock_executor, delay=0.5)
"""Test that queue flushes when command mode changes."""
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.5))

action = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -140,7 +142,7 @@ async def test_action_queue_mode_change_flush(mock_executor):
@pytest.mark.asyncio
async def test_action_queue_label_change_flush(mock_executor):
"""Test that queue flushes when label changes."""
queue = ActionQueue(executor=mock_executor, delay=0.5)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.5))

action = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -167,7 +169,7 @@ async def test_action_queue_label_change_flush(mock_executor):
@pytest.mark.asyncio
async def test_action_queue_duplicate_device_merge(mock_executor):
"""Test that queue merges commands for duplicate devices."""
queue = ActionQueue(executor=mock_executor, delay=0.5)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.5))

action1 = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -191,7 +193,7 @@ async def test_action_queue_duplicate_device_merge(mock_executor):
@pytest.mark.asyncio
async def test_action_queue_duplicate_device_merge_order(mock_executor):
"""Test that command order is preserved when merging."""
queue = ActionQueue(executor=mock_executor, delay=0.1)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.1))

action1 = Action(
device_url="io://1234-5678-9012/1",
Expand Down Expand Up @@ -219,7 +221,7 @@ async def test_action_queue_duplicate_device_merge_does_not_mutate_inputs(
mock_executor,
):
"""Test that merge behavior does not mutate caller-owned Action commands."""
queue = ActionQueue(executor=mock_executor, delay=0.1)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.1))

action1 = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -240,7 +242,9 @@ async def test_action_queue_duplicate_device_merge_does_not_mutate_inputs(
@pytest.mark.asyncio
async def test_action_queue_manual_flush(mock_executor):
"""Test manual flush of the queue."""
queue = ActionQueue(executor=mock_executor, delay=10.0) # Long delay
queue = ActionQueue(
executor=mock_executor, settings=ActionQueueSettings(delay=10.0)
) # Long delay

action = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -261,7 +265,9 @@ async def test_action_queue_manual_flush(mock_executor):
@pytest.mark.asyncio
async def test_action_queue_shutdown(mock_executor):
"""Test that shutdown flushes pending actions."""
queue = ActionQueue(executor=mock_executor, delay=10.0)
queue = ActionQueue(
executor=mock_executor, settings=ActionQueueSettings(delay=10.0)
)

action = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -284,7 +290,7 @@ async def test_action_queue_error_propagation(mock_executor):
# Make executor raise an exception
mock_executor.side_effect = ValueError("API Error")

queue = ActionQueue(executor=mock_executor, delay=0.1)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.1))

action = Action(
device_url="io://1234-5678-9012/1",
Expand All @@ -306,7 +312,7 @@ async def test_action_queue_error_propagation(mock_executor):
async def test_action_queue_get_pending_count():
"""Test getting pending action count."""
mock_executor = AsyncMock(return_value="exec-123")
queue = ActionQueue(executor=mock_executor, delay=0.5)
queue = ActionQueue(executor=mock_executor, settings=ActionQueueSettings(delay=0.5))

assert queue.get_pending_count() == 0

Expand Down
Loading