Skip to content

Commit 51505d9

Browse files
committed
Simplify ActionQueue: extract _merge_actions, deduplicate _delayed_flush, fix shutdown lock
- Extract _merge_actions() helper to deduplicate action-merging logic in add() - Replace hand-rolled snapshot in _delayed_flush with _prepare_flush() call - Move cancelled task await outside lock in shutdown() to prevent potential deadlock
1 parent 14f13ae commit 51505d9

1 file changed

Lines changed: 35 additions & 48 deletions

File tree

pyoverkiz/action_queue.py

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,26 @@ def __init__(
107107
self._lock = asyncio.Lock()
108108

109109
@staticmethod
110-
def _copy_action(action: Action) -> Action:
111-
"""Return an `Action` copy with an independent commands list.
112-
113-
The queue merges commands for duplicate devices, so caller-owned action
114-
instances must be copied to avoid mutating user input while batching.
115-
"""
116-
return Action(device_url=action.device_url, commands=list(action.commands))
110+
def _merge_actions(
111+
target: list[Action],
112+
index: dict[str, Action],
113+
source: list[Action],
114+
*,
115+
copy: bool = False,
116+
) -> None:
117+
"""Merge *source* actions into *target*, combining commands for duplicate devices."""
118+
for action in source:
119+
existing = index.get(action.device_url)
120+
if existing is None:
121+
merged = (
122+
Action(device_url=action.device_url, commands=list(action.commands))
123+
if copy
124+
else action
125+
)
126+
target.append(merged)
127+
index[action.device_url] = merged
128+
else:
129+
existing.commands.extend(action.commands)
117130

118131
async def add(
119132
self,
@@ -146,14 +159,7 @@ async def add(
146159

147160
normalized_actions: list[Action] = []
148161
normalized_index: dict[str, Action] = {}
149-
for action in actions:
150-
existing = normalized_index.get(action.device_url)
151-
if existing is None:
152-
action_copy = self._copy_action(action)
153-
normalized_actions.append(action_copy)
154-
normalized_index[action.device_url] = action_copy
155-
else:
156-
existing.commands.extend(action.commands)
162+
self._merge_actions(normalized_actions, normalized_index, actions, copy=True)
157163

158164
async with self._lock:
159165
# If mode or label changes, flush existing queue first
@@ -162,18 +168,10 @@ async def add(
162168
):
163169
batches_to_execute.append(self._prepare_flush())
164170

165-
# Add actions to pending queue
166-
pending_index = {
167-
pending_action.device_url: pending_action
168-
for pending_action in self._pending_actions
169-
}
170-
for action in normalized_actions:
171-
pending = pending_index.get(action.device_url)
172-
if pending is None:
173-
self._pending_actions.append(action)
174-
pending_index[action.device_url] = action
175-
else:
176-
pending.commands.extend(action.commands)
171+
pending_index = {a.device_url: a for a in self._pending_actions}
172+
self._merge_actions(
173+
self._pending_actions, pending_index, normalized_actions
174+
)
177175
self._pending_mode = mode
178176
self._pending_label = label
179177

@@ -207,25 +205,13 @@ async def _delayed_flush(self) -> None:
207205
try:
208206
await asyncio.sleep(self._settings.delay)
209207
async with self._lock:
210-
if not self._pending_actions:
208+
batch = self._prepare_flush()
209+
if not batch[0]:
211210
return
211+
actions, mode, label, waiters = batch
212212

213-
# Take snapshot and clear state while holding lock
214-
actions = self._pending_actions
215-
mode = self._pending_mode
216-
label = self._pending_label
217-
waiters = self._pending_waiters
218-
219-
self._pending_actions = []
220-
self._pending_mode = None
221-
self._pending_label = None
222-
self._pending_waiters = []
223-
self._flush_task = None
224-
225-
# Execute outside the lock
226213
await self._execute_batch(actions, mode, label, waiters)
227214
except asyncio.CancelledError as exc:
228-
# Ensure all waiters are notified if this task is cancelled
229215
for waiter in waiters:
230216
waiter.set_exception(exc)
231217
raise
@@ -317,19 +303,20 @@ def get_pending_count(self) -> int:
317303

318304
async def shutdown(self) -> None:
319305
"""Shutdown the queue, flushing any pending actions."""
306+
cancelled_task: asyncio.Task[None] | None = None
320307
batch_to_execute = None
321308
async with self._lock:
322309
if self._flush_task and not self._flush_task.done():
323-
task = self._flush_task
324-
task.cancel()
310+
cancelled_task = self._flush_task
311+
cancelled_task.cancel()
325312
self._flush_task = None
326-
# Wait for cancellation to complete
327-
with contextlib.suppress(asyncio.CancelledError):
328-
await task
329313

330314
if self._pending_actions:
331315
batch_to_execute = self._prepare_flush()
332316

333-
# Execute outside the lock
317+
if cancelled_task:
318+
with contextlib.suppress(asyncio.CancelledError):
319+
await cancelled_task
320+
334321
if batch_to_execute:
335322
await self._execute_batch(*batch_to_execute)

0 commit comments

Comments
 (0)