@@ -1890,6 +1890,7 @@ def open_datatree(
     storage_options=None,
     zarr_version=None,
     zarr_format=None,
+    max_concurrency: int | None = None,
 ) -> DataTree:
     filename_or_obj = _normalize_path(filename_or_obj)

@@ -1927,6 +1928,7 @@ def open_datatree(
                 drop_variables=drop_variables,
                 use_cftime=use_cftime,
                 decode_timedelta=decode_timedelta,
+                max_concurrency=max_concurrency,
             )
         )
     else:
@@ -1965,6 +1967,7 @@ async def _open_datatree_from_stores_async(
     drop_variables: str | Iterable[str] | None = None,
     use_cftime=None,
     decode_timedelta=None,
+    max_concurrency: int | None = None,
 ) -> DataTree:
     """Async helper to open datasets from pre-opened stores and create indexes.

@@ -1973,28 +1976,24 @@ async def _open_datatree_from_stores_async(
     """
     from xarray.backends.api import _maybe_create_default_indexes_async

-    # Limit concurrent to_thread calls to avoid deadlocks with some stores
-    # (e.g., icechunk can deadlock when too many threads access it simultaneously)
-    sem = asyncio.Semaphore(10)
+    if max_concurrency is None:
+        max_concurrency = 10
+    sem = asyncio.Semaphore(max_concurrency)

     async def open_one(path_group: str, store) -> tuple[str, Dataset]:
         async with sem:
             store_entrypoint = StoreBackendEntrypoint()
-
-            def _load_sync():
-                with close_on_error(store):
-                    return store_entrypoint.open_dataset(
-                        store,
-                        mask_and_scale=mask_and_scale,
-                        decode_times=decode_times,
-                        concat_characters=concat_characters,
-                        decode_coords=decode_coords,
-                        drop_variables=drop_variables,
-                        use_cftime=use_cftime,
-                        decode_timedelta=decode_timedelta,
-                    )
-
-            ds = await asyncio.to_thread(_load_sync)
+            with close_on_error(store):
+                ds = await store_entrypoint.open_dataset_async(
+                    store,
+                    mask_and_scale=mask_and_scale,
+                    decode_times=decode_times,
+                    concat_characters=concat_characters,
+                    decode_coords=decode_coords,
+                    drop_variables=drop_variables,
+                    use_cftime=use_cftime,
+                    decode_timedelta=decode_timedelta,
+                )
             # Create indexes in parallel (within this group)
             ds = await _maybe_create_default_indexes_async(ds)
             if group:
@@ -2003,10 +2002,15 @@ def _load_sync():
                 group_name = str(NodePath(path_group))
             return group_name, ds

-    # Open all datasets and create indexes concurrently
-    tasks = [open_one(path_group, store) for path_group, store in stores.items()]
-    results = await asyncio.gather(*tasks)
-    groups_dict = dict(results)
+    groups_dict: dict[str, Dataset] = {}
+
+    async def collect_result(path_group: str, store) -> None:
+        group_name, ds = await open_one(path_group, store)
+        groups_dict[group_name] = ds
+
+    async with asyncio.TaskGroup() as tg:
+        for path_group, store in stores.items():
+            tg.create_task(collect_result(path_group, store))

     return datatree_from_dict_with_io_cleanup(groups_dict)
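The diff above threads a new `max_concurrency` keyword through `open_datatree` into the async helper, where it caps how many groups are opened at once (defaulting to 10 when `None`). A hypothetical call might look like the sketch below; it assumes this patch is applied and that the top-level `xarray.open_datatree` forwards backend keyword arguments unchanged, and the store path is a placeholder.

```python
import xarray as xr

# Hypothetical usage, assuming this patch is applied and the keyword is
# forwarded to the backend; the path below is a placeholder.
tree = xr.open_datatree(
    "s3://bucket/store.zarr",
    engine="zarr",
    max_concurrency=4,  # open at most 4 groups concurrently; None keeps the default of 10
)
```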
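The helper also swaps `asyncio.gather` for `asyncio.TaskGroup` (Python 3.11+) while keeping a `Semaphore` to bound concurrency. The standalone sketch below shows that pattern in isolation, with a fake async "open" step standing in for the real store access; all names here are illustrative, not part of the xarray API.

```python
import asyncio

async def open_all(paths: list[str], max_concurrency: int | None = None) -> dict[str, str]:
    # Cap how many coroutines run at once, mirroring the default of 10 above.
    if max_concurrency is None:
        max_concurrency = 10
    sem = asyncio.Semaphore(max_concurrency)
    results: dict[str, str] = {}

    async def open_one(path: str) -> None:
        async with sem:
            await asyncio.sleep(0.01)  # stand-in for the real async open
            results[path] = f"dataset-for-{path}"

    # TaskGroup cancels the remaining tasks if any one of them raises,
    # unlike asyncio.gather with default arguments.
    async with asyncio.TaskGroup() as tg:
        for path in paths:
            tg.create_task(open_one(path))
    return results

print(asyncio.run(open_all(["/", "/child", "/child/grandchild"], max_concurrency=2)))
```

Collecting results by mutating a shared dict is safe here because every task runs on the same event loop, so the writes never interleave mid-operation.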