Skip to content

Commit 8eb7da7

Browse files
committed
Fix async deadlock risks: use zarr built-in members(), run_in_executor safety
- Replace custom _iter_zarr_groups_async (~90 lines) with zarr's AsyncGroup.members(max_depth=None) to avoid sync fallback deadlock
- Wrap _get_open_params and _build_group_members in run_in_executor in open_store_async to prevent reentrant sync() deadlock
- Add dedicated executor to _maybe_create_default_indexes_async and create_indexes_async to avoid thread pool exhaustion on zarr's IO loop
1 parent d90bb91 commit 8eb7da7

2 files changed

Lines changed: 201 additions & 164 deletions

File tree

xarray/backends/api.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -361,21 +361,31 @@ def _datatree_from_backend_datatree(
361361

362362
async def create_indexes_async() -> dict[str, Dataset]:
363363
import asyncio
364+
from concurrent.futures import ThreadPoolExecutor
364365

365-
results: dict[str, Dataset] = {}
366-
tasks = [
367-
_create_index_for_node(path, node.dataset)
368-
for path, [node] in group_subtrees(backend_tree)
369-
]
370-
for fut in asyncio.as_completed(tasks):
371-
path, ds = await fut
372-
results[path] = ds
373-
return results
374-
375-
async def _create_index_for_node(
376-
path: str, ds: Dataset
377-
) -> tuple[str, Dataset]:
378-
return path, await _maybe_create_default_indexes_async(ds)
366+
executor = ThreadPoolExecutor(
367+
max_workers=10, thread_name_prefix="xarray-idx"
368+
)
369+
try:
370+
results: dict[str, Dataset] = {}
371+
372+
async def _create_index_for_node(
373+
path: str, ds: Dataset
374+
) -> tuple[str, Dataset]:
375+
return path, await _maybe_create_default_indexes_async(
376+
ds, executor=executor
377+
)
378+
379+
tasks = [
380+
_create_index_for_node(path, node.dataset)
381+
for path, [node] in group_subtrees(backend_tree)
382+
]
383+
for fut in asyncio.as_completed(tasks):
384+
path, ds = await fut
385+
results[path] = ds
386+
return results
387+
finally:
388+
executor.shutdown(wait=True, cancel_futures=True)
379389

380390
results = zarr_sync(create_indexes_async())
381391
tree = DataTree.from_dict(results, name=backend_tree.name)
@@ -417,12 +427,10 @@ async def _create_index_for_node(
417427
return tree
418428

419429

420-
async def _maybe_create_default_indexes_async(ds: Dataset) -> Dataset:
421-
"""Create default indexes for dimension coordinates asynchronously.
430+
async def _maybe_create_default_indexes_async(
431+
ds: Dataset, executor=None
432+
) -> Dataset:
422433

423-
This function parallelizes both data loading and index creation,
424-
which can significantly speed up opening datasets with many coordinates.
425-
"""
426434
import asyncio
427435

428436
to_index_names = [
@@ -434,11 +442,13 @@ async def _maybe_create_default_indexes_async(ds: Dataset) -> Dataset:
434442
if not to_index_names:
435443
return ds
436444

445+
loop = asyncio.get_running_loop()
446+
437447
async def load_var(var: Variable) -> Variable:
438448
try:
439449
return await var.load_async()
440450
except NotImplementedError:
441-
return await asyncio.to_thread(var.load)
451+
return await loop.run_in_executor(executor, var.load)
442452

443453
await asyncio.gather(*[load_var(ds.variables[name]) for name in to_index_names])
444454

0 commit comments

Comments (0)