
Commit 27ba9f6

Address PR review: TaskGroup, max_concurrency, and open_dataset_async
- Replace asyncio.gather with asyncio.TaskGroup for better error handling (cancels outstanding tasks on error)
- Add max_concurrency parameter to open_datatree for controlling parallel I/O operations (defaults to 10)
- Add StoreBackendEntrypoint.open_dataset_async method
- Add test for open_dataset_async equivalence
1 parent 980c882 commit 27ba9f6

5 files changed

Lines changed: 98 additions & 22 deletions
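
As a usage sketch (not part of the diff below; the store path and concurrency value are made up), the new parameter introduced by this commit would be passed through xr.open_datatree like this:

import xarray as xr

# Hypothetical store path; with the Zarr v3 backend, groups are opened
# concurrently and max_concurrency caps how many are in flight at once
# (the backend falls back to 10 when the argument is left as None).
tree = xr.open_datatree(
    "simulations.zarr",
    engine="zarr",
    max_concurrency=4,
)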


doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@ New Features
 
 - Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit
   all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`).
+- Added ``max_concurrency`` parameter to :py:func:`open_datatree` to control
+  the maximum number of concurrent I/O operations when opening groups in parallel
+  with the Zarr backend (:pull:`10742`).
   By `Alfonso Ladino <https://github.com/aladinor>`_.
 
 Breaking Changes

xarray/backends/api.py

Lines changed: 11 additions & 0 deletions
@@ -943,6 +943,7 @@ def open_datatree(
     chunked_array_type: str | None = None,
     from_array_kwargs: dict[str, Any] | None = None,
     backend_kwargs: dict[str, Any] | None = None,
+    max_concurrency: int | None = None,
     **kwargs,
 ) -> DataTree:
     """
@@ -1075,6 +1076,13 @@ def open_datatree(
         chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
         For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed
         to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
+    max_concurrency : int, optional
+        Maximum number of concurrent I/O operations when opening groups in
+        parallel. This limits the number of groups that are loaded simultaneously.
+        Useful for controlling resource usage with large datatrees or stores
+        that may have limitations on concurrent access (e.g., icechunk).
+        Only used by backends that support parallel loading (currently Zarr v3).
+        If None (default), the backend uses its default value (typically 10).
     backend_kwargs: dict
         Additional keyword arguments passed on to the engine open function,
         equivalent to `**kwargs`.
@@ -1135,6 +1143,9 @@ def open_datatree(
     )
     overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
 
+    if max_concurrency is not None:
+        kwargs["max_concurrency"] = max_concurrency
+
     backend_tree = backend.open_datatree(
         filename_or_obj,
         drop_variables=drop_variables,
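
The api.py change forwards max_concurrency into **kwargs only when the caller sets it, so backends whose open_datatree does not accept the argument never see an unexpected keyword. A minimal standalone sketch of that forwarding pattern, using generic stand-in functions rather than xarray's actual internals:

def open_datatree(path, *, max_concurrency=None, **kwargs):
    # Forward the option only when explicitly given.
    if max_concurrency is not None:
        kwargs["max_concurrency"] = max_concurrency
    return backend_open_datatree(path, **kwargs)


def backend_open_datatree(path, max_concurrency=10, **kwargs):
    # A backend that supports the option declares it with its own default.
    return f"open {path} with at most {max_concurrency} concurrent group reads"


print(open_datatree("store.zarr"))                     # backend default (10)
print(open_datatree("store.zarr", max_concurrency=4))  # caller override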

xarray/backends/store.py

Lines changed: 33 additions & 0 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
 
@@ -72,5 +73,37 @@ def open_dataset(
 
         return ds
 
+    async def open_dataset_async(
+        self,
+        filename_or_obj: T_PathFileOrDataStore,
+        *,
+        mask_and_scale=True,
+        decode_times=True,
+        concat_characters=True,
+        decode_coords=True,
+        drop_variables: str | Iterable[str] | None = None,
+        set_indexes: bool = True,
+        use_cftime=None,
+        decode_timedelta=None,
+    ) -> Dataset:
+        """Async version of open_dataset.
+
+        Offloads the entire open_dataset operation to a thread to avoid blocking
+        the event loop. This is necessary because decode_cf_variables can trigger
+        data reads (e.g., for time decoding) which may use synchronous I/O.
+        """
+        return await asyncio.to_thread(
+            self.open_dataset,
+            filename_or_obj,
+            mask_and_scale=mask_and_scale,
+            decode_times=decode_times,
+            concat_characters=concat_characters,
+            decode_coords=decode_coords,
+            drop_variables=drop_variables,
+            set_indexes=set_indexes,
+            use_cftime=use_cftime,
+            decode_timedelta=decode_timedelta,
+        )
+
 
 BACKEND_ENTRYPOINTS["store"] = (None, StoreBackendEntrypoint)
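
The new open_dataset_async is a thin async wrapper: the blocking open (including CF decoding that may read data) is pushed onto a worker thread with asyncio.to_thread. A self-contained sketch of the same pattern, with a stand-in blocking function in place of the real store-backed open:

import asyncio
import time


def open_dataset_sync(path: str) -> str:
    # Stand-in for the blocking open + decode work described in the docstring above.
    time.sleep(0.1)  # simulated synchronous I/O
    return f"dataset from {path}"


async def open_dataset_async(path: str) -> str:
    # Run the blocking call in a worker thread so the event loop stays free
    # to drive other coroutines (e.g. opening other groups).
    return await asyncio.to_thread(open_dataset_sync, path)


async def main() -> None:
    paths = ["a.zarr", "b.zarr", "c.zarr"]
    datasets = await asyncio.gather(*(open_dataset_async(p) for p in paths))
    print(datasets)


asyncio.run(main())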

xarray/backends/zarr.py

Lines changed: 26 additions & 22 deletions
@@ -1890,6 +1890,7 @@ def open_datatree(
         storage_options=None,
         zarr_version=None,
         zarr_format=None,
+        max_concurrency: int | None = None,
     ) -> DataTree:
         filename_or_obj = _normalize_path(filename_or_obj)
 
@@ -1927,6 +1928,7 @@ def open_datatree(
                     drop_variables=drop_variables,
                     use_cftime=use_cftime,
                     decode_timedelta=decode_timedelta,
+                    max_concurrency=max_concurrency,
                 )
             )
         else:
@@ -1965,6 +1967,7 @@ async def _open_datatree_from_stores_async(
         drop_variables: str | Iterable[str] | None = None,
         use_cftime=None,
         decode_timedelta=None,
+        max_concurrency: int | None = None,
     ) -> DataTree:
         """Async helper to open datasets from pre-opened stores and create indexes.
 
@@ -1973,28 +1976,24 @@ async def _open_datatree_from_stores_async(
         """
         from xarray.backends.api import _maybe_create_default_indexes_async
 
-        # Limit concurrent to_thread calls to avoid deadlocks with some stores
-        # (e.g., icechunk can deadlock when too many threads access it simultaneously)
-        sem = asyncio.Semaphore(10)
+        if max_concurrency is None:
+            max_concurrency = 10
+        sem = asyncio.Semaphore(max_concurrency)
 
         async def open_one(path_group: str, store) -> tuple[str, Dataset]:
             async with sem:
                 store_entrypoint = StoreBackendEntrypoint()
-
-                def _load_sync():
-                    with close_on_error(store):
-                        return store_entrypoint.open_dataset(
-                            store,
-                            mask_and_scale=mask_and_scale,
-                            decode_times=decode_times,
-                            concat_characters=concat_characters,
-                            decode_coords=decode_coords,
-                            drop_variables=drop_variables,
-                            use_cftime=use_cftime,
-                            decode_timedelta=decode_timedelta,
-                        )
-
-                ds = await asyncio.to_thread(_load_sync)
+                with close_on_error(store):
+                    ds = await store_entrypoint.open_dataset_async(
+                        store,
+                        mask_and_scale=mask_and_scale,
+                        decode_times=decode_times,
+                        concat_characters=concat_characters,
+                        decode_coords=decode_coords,
+                        drop_variables=drop_variables,
+                        use_cftime=use_cftime,
+                        decode_timedelta=decode_timedelta,
+                    )
                 # Create indexes in parallel (within this group)
                 ds = await _maybe_create_default_indexes_async(ds)
                 if group:
@@ -2003,10 +2002,15 @@ def _load_sync():
                     group_name = str(NodePath(path_group))
             return group_name, ds
 
-        # Open all datasets and create indexes concurrently
-        tasks = [open_one(path_group, store) for path_group, store in stores.items()]
-        results = await asyncio.gather(*tasks)
-        groups_dict = dict(results)
+        groups_dict: dict[str, Dataset] = {}
+
+        async def collect_result(path_group: str, store) -> None:
+            group_name, ds = await open_one(path_group, store)
+            groups_dict[group_name] = ds
+
+        async with asyncio.TaskGroup() as tg:
+            for path_group, store in stores.items():
+                tg.create_task(collect_result(path_group, store))
 
         return datatree_from_dict_with_io_cleanup(groups_dict)
 
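
The zarr.py rewrite combines two pieces: an asyncio.Semaphore bounds how many groups are being opened at once (max_concurrency, default 10), and asyncio.TaskGroup (Python 3.11+) replaces asyncio.gather so a failure in one task cancels the remaining ones. A standalone sketch of that pattern, with a placeholder per-group open instead of the real store logic:

import asyncio

MAX_CONCURRENCY = 10  # same default the backend falls back to


async def open_group(name: str) -> tuple[str, str]:
    # Placeholder for the real per-group open; just yields to the event loop.
    await asyncio.sleep(0)
    return name, f"dataset for {name}"


async def open_all(names: list[str]) -> dict[str, str]:
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    results: dict[str, str] = {}

    async def worker(name: str) -> None:
        async with sem:  # bound the number of concurrent opens
            key, ds = await open_group(name)
        results[key] = ds

    # TaskGroup: if any task raises, the others are cancelled and the errors
    # are re-raised as an ExceptionGroup -- the behaviour this commit wanted
    # instead of asyncio.gather.
    async with asyncio.TaskGroup() as tg:
        for name in names:
            tg.create_task(worker(name))
    return results


print(asyncio.run(open_all(["/", "/group_a", "/group_b"])))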

xarray/tests/test_backends_zarr_async.py

Lines changed: 25 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 import xarray as xr
 from xarray.backends.api import _maybe_create_default_indexes_async
+from xarray.backends.store import StoreBackendEntrypoint
 from xarray.backends.zarr import ZarrBackendEntrypoint
 from xarray.testing import assert_equal
 from xarray.tests import (
@@ -232,3 +233,27 @@ def test_sync_open_datatree_uses_async_internally(self, zarr_format):
         # For zarr v3, the async function should be called
         assert mock_async.call_count > 0
         assert_equal(dtree, dtree_loaded)
+
+    @pytest.mark.asyncio
+    @requires_zarr_v3
+    @parametrize_zarr_format
+    async def test_store_backend_open_dataset_async_equivalence(self, zarr_format):
+        """Test that StoreBackendEntrypoint.open_dataset_async returns same result as sync."""
+        from xarray.backends.zarr import ZarrStore
+
+        ds = create_dataset_with_coordinates(n_coords=3)
+
+        with self.create_zarr_store() as store:
+            ds.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format)
+
+            zarr_store = ZarrStore.open_group(
+                store,
+                consolidated=False,
+                zarr_format=zarr_format,
+            )
+
+            store_entrypoint = StoreBackendEntrypoint()
+            ds_sync = store_entrypoint.open_dataset(zarr_store)
+            ds_async = await store_entrypoint.open_dataset_async(zarr_store)
+
+            assert_equal(ds_sync, ds_async)
