Skip to content

Commit 06a4a55

Browse files
committed
Use native async for DataTree open, threads only for CPU decode
Replace the per-group thread + zarr_sync pattern in _open_datatree_from_stores_async with a two-phase approach:

Phase 1: A single async_root.members(max_depth=None) call discovers all groups AND their array members in one pass, replacing both _iter_zarr_groups_async and the per-group _fetch_members calls.

Phase 2: Wrap each AsyncArray/AsyncGroup in a sync Array/Group (zero-cost), inject the pre-fetched members into ZarrStore, and run only the CPU-bound decode_cf_variables in the thread pool.

Results (laptop, 60-node OSN store):
- zarr_sync calls: 122 → 2
- members() calls: 61 → 1
- ~15-20% faster open_datatree (2.08s → 1.77s with indexes)
1 parent 06f3df8 commit 06a4a55

File tree

1 file changed

+54
-16
lines changed

1 file changed

+54
-16
lines changed

xarray/backends/zarr.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,12 +1931,9 @@ def open_datatree(
19311931
zarr_format=zarr_format,
19321932
)
19331933

1934-
group_paths = zarr_sync(_iter_zarr_groups_async(zarr_group, parent=parent))
1935-
19361934
return zarr_sync(
19371935
self._open_datatree_from_stores_async(
19381936
zarr_group=zarr_group,
1939-
group_paths=group_paths,
19401937
parent=parent,
19411938
group=group,
19421939
mode=mode,
@@ -1991,7 +1988,6 @@ def open_datatree(
19911988
async def _open_datatree_from_stores_async(
19921989
self,
19931990
zarr_group,
1994-
group_paths: list[str],
19951991
parent: str,
19961992
group: str | None,
19971993
*,
@@ -2008,25 +2004,65 @@ async def _open_datatree_from_stores_async(
20082004
decode_timedelta=None,
20092005
max_concurrency: int | None = None,
20102006
) -> DataTree:
2011-
"""Open datatree groups concurrently using a dedicated executor."""
2007+
"""Open datatree using native async for I/O, threads only for CPU decode."""
20122008
if max_concurrency is None:
20132009
max_concurrency = 10
20142010

20152011
loop = asyncio.get_running_loop()
20162012
executor = ThreadPoolExecutor(
2017-
max_workers=max_concurrency, thread_name_prefix="xarray"
2013+
max_workers=max_concurrency, thread_name_prefix="xarray-cpu"
20182014
)
20192015

2016+
from zarr import Array as ZarrSyncArray
2017+
from zarr import Group as ZarrSyncGroup
2018+
from zarr.core.group import AsyncGroup as ZarrAsyncGroup
2019+
2020+
async_root = zarr_group._async_group
2021+
parent_nodepath = NodePath(parent)
2022+
2023+
# Phase 1: Walk tree, collect groups + per-group members in one async pass.
2024+
# This replaces both _iter_zarr_groups_async and per-group _fetch_members.
2025+
group_async: dict[str, ZarrAsyncGroup] = {
2026+
str(parent_nodepath): async_root,
2027+
}
2028+
group_children: dict[str, dict] = {str(parent_nodepath): {}}
2029+
2030+
async for rel_path, member in async_root.members(max_depth=None):
2031+
full_path = str(parent_nodepath / rel_path)
2032+
2033+
if isinstance(member, ZarrAsyncGroup):
2034+
group_async[full_path] = member
2035+
group_children[full_path] = {}
2036+
2037+
parts = rel_path.rsplit("/", 1)
2038+
child_name = parts[-1]
2039+
parent_rel = parts[0] if len(parts) > 1 else ""
2040+
parent_path = (
2041+
str(parent_nodepath / parent_rel)
2042+
if parent_rel
2043+
else str(parent_nodepath)
2044+
)
2045+
if parent_path in group_children:
2046+
group_children[parent_path][child_name] = member
2047+
2048+
# Phase 2: Open each group — wrap async objects, run CPU decode in threads.
20202049
async def open_one(path_group: str) -> tuple[str, Dataset]:
2021-
def _sync_open():
2022-
if path_group == parent:
2023-
group_store = zarr_group
2024-
else:
2025-
rel_path = path_group.removeprefix(f"{parent}/").removeprefix("/")
2026-
group_store = zarr_group[rel_path]
2050+
async_grp = group_async[path_group]
2051+
children = group_children.get(path_group, {})
2052+
2053+
def _cpu_open():
2054+
sync_group = ZarrSyncGroup(async_grp)
2055+
sync_members = {
2056+
name: (
2057+
ZarrSyncGroup(child)
2058+
if isinstance(child, ZarrAsyncGroup)
2059+
else ZarrSyncArray(child)
2060+
)
2061+
for name, child in children.items()
2062+
}
20272063

20282064
store = ZarrStore(
2029-
group_store,
2065+
sync_group,
20302066
mode,
20312067
consolidate_on_close,
20322068
append_dim=None,
@@ -2036,8 +2072,10 @@ def _sync_open():
20362072
close_store_on_close=close_store_on_close,
20372073
use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
20382074
align_chunks=False,
2039-
cache_members=True,
2075+
cache_members=False,
20402076
)
2077+
store._members = sync_members
2078+
store._cache_members = True
20412079

20422080
store_entrypoint = StoreBackendEntrypoint()
20432081
with close_on_error(store):
@@ -2052,7 +2090,7 @@ def _sync_open():
20522090
decode_timedelta=decode_timedelta,
20532091
)
20542092

2055-
ds = await loop.run_in_executor(executor, _sync_open)
2093+
ds = await loop.run_in_executor(executor, _cpu_open)
20562094
if group:
20572095
group_name = str(NodePath(path_group).relative_to(parent))
20582096
else:
@@ -2067,7 +2105,7 @@ async def collect_result(path_group: str) -> None:
20672105

20682106
try:
20692107
async with asyncio.TaskGroup() as tg:
2070-
for path_group in group_paths:
2108+
for path_group in sorted(group_async.keys()):
20712109
tg.create_task(collect_result(path_group))
20722110
finally:
20732111
executor.shutdown(wait=True, cancel_futures=True)

0 commit comments

Comments
 (0)