@@ -1931,12 +1931,9 @@ def open_datatree(
19311931 zarr_format = zarr_format ,
19321932 )
19331933
1934- group_paths = zarr_sync (_iter_zarr_groups_async (zarr_group , parent = parent ))
1935-
19361934 return zarr_sync (
19371935 self ._open_datatree_from_stores_async (
19381936 zarr_group = zarr_group ,
1939- group_paths = group_paths ,
19401937 parent = parent ,
19411938 group = group ,
19421939 mode = mode ,
@@ -1991,7 +1988,6 @@ def open_datatree(
19911988 async def _open_datatree_from_stores_async (
19921989 self ,
19931990 zarr_group ,
1994- group_paths : list [str ],
19951991 parent : str ,
19961992 group : str | None ,
19971993 * ,
@@ -2008,25 +2004,65 @@ async def _open_datatree_from_stores_async(
20082004 decode_timedelta = None ,
20092005 max_concurrency : int | None = None ,
20102006 ) -> DataTree :
2011- """Open datatree groups concurrently using a dedicated executor ."""
2007+ """Open datatree using native async for I/O, threads only for CPU decode ."""
20122008 if max_concurrency is None :
20132009 max_concurrency = 10
20142010
20152011 loop = asyncio .get_running_loop ()
20162012 executor = ThreadPoolExecutor (
2017- max_workers = max_concurrency , thread_name_prefix = "xarray"
2013+ max_workers = max_concurrency , thread_name_prefix = "xarray-cpu "
20182014 )
20192015
2016+ from zarr import Array as ZarrSyncArray
2017+ from zarr import Group as ZarrSyncGroup
2018+ from zarr .core .group import AsyncGroup as ZarrAsyncGroup
2019+
2020+ async_root = zarr_group ._async_group
2021+ parent_nodepath = NodePath (parent )
2022+
2023+ # Phase 1: Walk tree, collect groups + per-group members in one async pass.
2024+ # This replaces both _iter_zarr_groups_async and per-group _fetch_members.
2025+ group_async : dict [str , ZarrAsyncGroup ] = {
2026+ str (parent_nodepath ): async_root ,
2027+ }
2028+ group_children : dict [str , dict ] = {str (parent_nodepath ): {}}
2029+
2030+ async for rel_path , member in async_root .members (max_depth = None ):
2031+ full_path = str (parent_nodepath / rel_path )
2032+
2033+ if isinstance (member , ZarrAsyncGroup ):
2034+ group_async [full_path ] = member
2035+ group_children [full_path ] = {}
2036+
2037+ parts = rel_path .rsplit ("/" , 1 )
2038+ child_name = parts [- 1 ]
2039+ parent_rel = parts [0 ] if len (parts ) > 1 else ""
2040+ parent_path = (
2041+ str (parent_nodepath / parent_rel )
2042+ if parent_rel
2043+ else str (parent_nodepath )
2044+ )
2045+ if parent_path in group_children :
2046+ group_children [parent_path ][child_name ] = member
2047+
2048+ # Phase 2: Open each group — wrap async objects, run CPU decode in threads.
20202049 async def open_one (path_group : str ) -> tuple [str , Dataset ]:
2021- def _sync_open ():
2022- if path_group == parent :
2023- group_store = zarr_group
2024- else :
2025- rel_path = path_group .removeprefix (f"{ parent } /" ).removeprefix ("/" )
2026- group_store = zarr_group [rel_path ]
2050+ async_grp = group_async [path_group ]
2051+ children = group_children .get (path_group , {})
2052+
2053+ def _cpu_open ():
2054+ sync_group = ZarrSyncGroup (async_grp )
2055+ sync_members = {
2056+ name : (
2057+ ZarrSyncGroup (child )
2058+ if isinstance (child , ZarrAsyncGroup )
2059+ else ZarrSyncArray (child )
2060+ )
2061+ for name , child in children .items ()
2062+ }
20272063
20282064 store = ZarrStore (
2029- group_store ,
2065+ sync_group ,
20302066 mode ,
20312067 consolidate_on_close ,
20322068 append_dim = None ,
@@ -2036,8 +2072,10 @@ def _sync_open():
20362072 close_store_on_close = close_store_on_close ,
20372073 use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask ,
20382074 align_chunks = False ,
2039- cache_members = True ,
2075+ cache_members = False ,
20402076 )
2077+ store ._members = sync_members
2078+ store ._cache_members = True
20412079
20422080 store_entrypoint = StoreBackendEntrypoint ()
20432081 with close_on_error (store ):
@@ -2052,7 +2090,7 @@ def _sync_open():
20522090 decode_timedelta = decode_timedelta ,
20532091 )
20542092
2055- ds = await loop .run_in_executor (executor , _sync_open )
2093+ ds = await loop .run_in_executor (executor , _cpu_open )
20562094 if group :
20572095 group_name = str (NodePath (path_group ).relative_to (parent ))
20582096 else :
@@ -2067,7 +2105,7 @@ async def collect_result(path_group: str) -> None:
20672105
20682106 try :
20692107 async with asyncio .TaskGroup () as tg :
2070- for path_group in group_paths :
2108+ for path_group in sorted ( group_async . keys ()) :
20712109 tg .create_task (collect_result (path_group ))
20722110 finally :
20732111 executor .shutdown (wait = True , cancel_futures = True )
0 commit comments