Skip to content

Commit ebf203e

Browse files
committed
Support glob patterns in open_datatree(group=...) for selective group loading
When the ``group`` parameter contains glob metacharacters (``*``, ``?``, ``[``), the backends filter which groups are opened instead of re-rooting the tree. This avoids loading the entire hierarchy when only a subset is needed. Adds the shared utilities ``_is_glob_pattern``, ``_filter_group_paths``, and ``_resolve_group_and_filter`` in ``common.py``. Updates the NetCDF4, H5NetCDF, and Zarr backends to use the discover–filter–open pipeline. Includes unit tests for the utilities and integration tests across all backends. Closes #11196
1 parent 06a4a55 commit ebf203e

File tree

6 files changed

+294
-33
lines changed

6 files changed

+294
-33
lines changed

xarray/backends/api.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,8 +1097,12 @@ def open_datatree(
10971097
Additional keyword arguments passed on to the engine open function.
10981098
For example:
10991099
1100-
- 'group': path to the group in the given file to open as the root group as
1101-
a str.
1100+
- 'group': path to the group in the given file to open as the root
1101+
group as a str. If the string contains glob metacharacters
1102+
(``*``, ``?``, ``[``), it is interpreted as a pattern and only
1103+
groups whose paths match are loaded (along with their ancestors).
1104+
For example, ``group="*/sweep_0"`` loads every ``sweep_0`` one
1105+
level deep while skipping sibling groups.
11021106
- 'lock': resource lock to use when reading data from disk. Only
11031107
relevant when using dask or another form of parallelism. By default,
11041108
appropriate locks are chosen to safely read and write files with the
@@ -1344,8 +1348,12 @@ def open_groups(
13441348
Additional keyword arguments passed on to the engine open function.
13451349
For example:
13461350
1347-
- 'group': path to the group in the given file to open as the root group as
1348-
a str.
1351+
- 'group': path to the group in the given file to open as the root
1352+
group as a str. If the string contains glob metacharacters
1353+
(``*``, ``?``, ``[``), it is interpreted as a pattern and only
1354+
groups whose paths match are loaded (along with their ancestors).
1355+
For example, ``group="*/sweep_0"`` loads every ``sweep_0`` one
1356+
level deep while skipping sibling groups.
13491357
- 'lock': resource lock to use when reading data from disk. Only
13501358
relevant when using dask or another form of parallelism. By default,
13511359
appropriate locks are chosen to safely read and write files with the

xarray/backends/common.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,37 @@ def _iter_nc_groups(root, parent="/"):
248248
yield from _iter_nc_groups(group, parent=gpath)
249249

250250

251+
def _is_glob_pattern(pattern: str) -> bool:
252+
return any(c in pattern for c in "*?[")
253+
254+
255+
def _filter_group_paths(group_paths: Iterable[str], pattern: str) -> list[str]:
256+
from xarray.core.treenode import NodePath
257+
258+
matched: set[str] = {"/"}
259+
for path in group_paths:
260+
np_ = NodePath(path)
261+
if np_.match(pattern):
262+
matched.add(path)
263+
for parent in np_.parents:
264+
p = str(parent)
265+
if p:
266+
matched.add(p)
267+
268+
return [p for p in group_paths if p in matched]
269+
270+
271+
def _resolve_group_and_filter(
272+
group: str | None,
273+
all_group_paths: list[str],
274+
) -> tuple[str | None, list[str]]:
275+
if group is None:
276+
return None, all_group_paths
277+
if _is_glob_pattern(group):
278+
return None, _filter_group_paths(all_group_paths, group)
279+
return group, all_group_paths
280+
281+
251282
def find_root_and_group(ds):
252283
"""Find the root and group name of a netCDF4/h5netcdf dataset."""
253284
hierarchy = ()

xarray/backends/h5netcdf_.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,11 @@ def open_groups_as_dict(
655655
open_kwargs: dict[str, Any] | None = None,
656656
**kwargs,
657657
) -> dict[str, Dataset]:
658-
from xarray.backends.common import _iter_nc_groups
658+
from xarray.backends.common import (
659+
_is_glob_pattern,
660+
_iter_nc_groups,
661+
_resolve_group_and_filter,
662+
)
659663
from xarray.core.treenode import NodePath
660664
from xarray.core.utils import close_on_error
661665

@@ -664,10 +668,12 @@ def open_groups_as_dict(
664668
emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims)
665669

666670
filename_or_obj = _normalize_filename_or_obj(filename_or_obj)
671+
672+
effective_group = None if (group and _is_glob_pattern(group)) else group
667673
store = H5NetCDFStore.open(
668674
filename_or_obj,
669675
format=format,
670-
group=group,
676+
group=effective_group,
671677
lock=lock,
672678
invalid_netcdf=invalid_netcdf,
673679
phony_dims=phony_dims,
@@ -678,15 +684,17 @@ def open_groups_as_dict(
678684
open_kwargs=open_kwargs,
679685
)
680686

681-
# Check for a group and make it a parent if it exists
682-
if group:
683-
parent = NodePath("/") / NodePath(group)
687+
if effective_group:
688+
parent = NodePath("/") / NodePath(effective_group)
684689
else:
685690
parent = NodePath("/")
686691

687692
manager = store._manager
693+
all_group_paths = list(_iter_nc_groups(store.ds, parent=parent))
694+
_, filtered_paths = _resolve_group_and_filter(group, all_group_paths)
695+
688696
groups_dict = {}
689-
for path_group in _iter_nc_groups(store.ds, parent=parent):
697+
for path_group in filtered_paths:
690698
group_store = H5NetCDFStore(manager, group=path_group, **kwargs)
691699
store_entrypoint = StoreBackendEntrypoint()
692700
with close_on_error(group_store):
@@ -701,7 +709,7 @@ def open_groups_as_dict(
701709
decode_timedelta=decode_timedelta,
702710
)
703711

704-
if group:
712+
if effective_group:
705713
group_name = str(NodePath(path_group).relative_to(parent))
706714
else:
707715
group_name = str(NodePath(path_group))

xarray/backends/netCDF4_.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -859,13 +859,19 @@ def open_groups_as_dict(
859859
autoclose=False,
860860
**kwargs,
861861
) -> dict[str, Dataset]:
862-
from xarray.backends.common import _iter_nc_groups
862+
from xarray.backends.common import (
863+
_is_glob_pattern,
864+
_iter_nc_groups,
865+
_resolve_group_and_filter,
866+
)
863867
from xarray.core.treenode import NodePath
864868

865869
filename_or_obj = _normalize_path(filename_or_obj)
870+
871+
effective_group = None if (group and _is_glob_pattern(group)) else group
866872
store = NetCDF4DataStore.open(
867873
filename_or_obj,
868-
group=group,
874+
group=effective_group,
869875
format=format,
870876
clobber=clobber,
871877
diskless=diskless,
@@ -875,15 +881,17 @@ def open_groups_as_dict(
875881
autoclose=autoclose,
876882
)
877883

878-
# Check for a group and make it a parent if it exists
879-
if group:
880-
parent = NodePath("/") / NodePath(group)
884+
if effective_group:
885+
parent = NodePath("/") / NodePath(effective_group)
881886
else:
882887
parent = NodePath("/")
883888

884889
manager = store._manager
890+
all_group_paths = list(_iter_nc_groups(store.ds, parent=parent))
891+
_, filtered_paths = _resolve_group_and_filter(group, all_group_paths)
892+
885893
groups_dict = {}
886-
for path_group in _iter_nc_groups(store.ds, parent=parent):
894+
for path_group in filtered_paths:
887895
group_store = NetCDF4DataStore(manager, group=path_group, **kwargs)
888896
store_entrypoint = StoreBackendEntrypoint()
889897
with close_on_error(group_store):
@@ -897,7 +905,7 @@ def open_groups_as_dict(
897905
use_cftime=use_cftime,
898906
decode_timedelta=decode_timedelta,
899907
)
900-
if group:
908+
if effective_group:
901909
group_name = str(NodePath(path_group).relative_to(parent))
902910
else:
903911
group_name = str(NodePath(path_group))

xarray/backends/zarr.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,10 +1900,13 @@ def open_datatree(
19001900
zarr_format=None,
19011901
max_concurrency: int | None = None,
19021902
) -> DataTree:
1903+
from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter
1904+
19031905
filename_or_obj = _normalize_path(filename_or_obj)
19041906

1905-
if group:
1906-
parent = str(NodePath("/") / NodePath(group))
1907+
effective_group = None if (group and _is_glob_pattern(group)) else group
1908+
if effective_group:
1909+
parent = str(NodePath("/") / NodePath(effective_group))
19071910
else:
19081911
parent = str(NodePath("/"))
19091912

@@ -1964,8 +1967,11 @@ def open_datatree(
19641967
zarr_version=zarr_version,
19651968
zarr_format=zarr_format,
19661969
)
1970+
all_paths = list(stores.keys())
1971+
_, filtered_paths = _resolve_group_and_filter(group, all_paths)
19671972
groups_dict = {}
1968-
for path_group, store in stores.items():
1973+
for path_group in filtered_paths:
1974+
store = stores[path_group]
19691975
store_entrypoint = StoreBackendEntrypoint()
19701976
with close_on_error(store):
19711977
group_ds = store_entrypoint.open_dataset(
@@ -1978,7 +1984,7 @@ def open_datatree(
19781984
use_cftime=use_cftime,
19791985
decode_timedelta=decode_timedelta,
19801986
)
1981-
if group:
1987+
if effective_group:
19821988
group_name = str(NodePath(path_group).relative_to(parent))
19831989
else:
19841990
group_name = str(NodePath(path_group))
@@ -2045,6 +2051,16 @@ async def _open_datatree_from_stores_async(
20452051
if parent_path in group_children:
20462052
group_children[parent_path][child_name] = member
20472053

2054+
# Filter groups when glob pattern is used
2055+
from xarray.backends.common import _resolve_group_and_filter
2056+
2057+
effective_group, filtered_paths = _resolve_group_and_filter(
2058+
group, list(group_async.keys())
2059+
)
2060+
filtered_set = set(filtered_paths)
2061+
group_async = {k: v for k, v in group_async.items() if k in filtered_set}
2062+
group_children = {k: v for k, v in group_children.items() if k in filtered_set}
2063+
20482064
# Phase 2: Open each group — wrap async objects, run CPU decode in threads.
20492065
async def open_one(path_group: str) -> tuple[str, Dataset]:
20502066
async_grp = group_async[path_group]
@@ -2091,7 +2107,7 @@ def _cpu_open():
20912107
)
20922108

20932109
ds = await loop.run_in_executor(executor, _cpu_open)
2094-
if group:
2110+
if effective_group:
20952111
group_name = str(NodePath(path_group).relative_to(parent))
20962112
else:
20972113
group_name = str(NodePath(path_group))
@@ -2132,11 +2148,13 @@ def open_groups_as_dict(
21322148
zarr_version=None,
21332149
zarr_format=None,
21342150
) -> dict[str, Dataset]:
2151+
from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter
2152+
21352153
filename_or_obj = _normalize_path(filename_or_obj)
21362154

2137-
# Check for a group and make it a parent if it exists
2138-
if group:
2139-
parent = str(NodePath("/") / NodePath(group))
2155+
effective_group = None if (group and _is_glob_pattern(group)) else group
2156+
if effective_group:
2157+
parent = str(NodePath("/") / NodePath(effective_group))
21402158
else:
21412159
parent = str(NodePath("/"))
21422160

@@ -2153,8 +2171,11 @@ def open_groups_as_dict(
21532171
zarr_format=zarr_format,
21542172
)
21552173

2174+
_, filtered_paths = _resolve_group_and_filter(group, list(stores.keys()))
2175+
21562176
groups_dict = {}
2157-
for path_group, store in stores.items():
2177+
for path_group in filtered_paths:
2178+
store = stores[path_group]
21582179
store_entrypoint = StoreBackendEntrypoint()
21592180

21602181
with close_on_error(store):
@@ -2168,7 +2189,7 @@ def open_groups_as_dict(
21682189
use_cftime=use_cftime,
21692190
decode_timedelta=decode_timedelta,
21702191
)
2171-
if group:
2192+
if effective_group:
21722193
group_name = str(NodePath(path_group).relative_to(parent))
21732194
else:
21742195
group_name = str(NodePath(path_group))
@@ -2200,11 +2221,13 @@ async def open_groups_as_dict_async(
22002221
This mirrors open_groups_as_dict but parallelizes per-group Dataset opening,
22012222
which can significantly reduce latency on high-RTT object stores.
22022223
"""
2224+
from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter
2225+
22032226
filename_or_obj = _normalize_path(filename_or_obj)
22042227

2205-
# Determine parent group path context
2206-
if group:
2207-
parent = str(NodePath("/") / NodePath(group))
2228+
effective_group = None if (group and _is_glob_pattern(group)) else group
2229+
if effective_group:
2230+
parent = str(NodePath("/") / NodePath(effective_group))
22082231
else:
22092232
parent = str(NodePath("/"))
22102233

@@ -2221,6 +2244,9 @@ async def open_groups_as_dict_async(
22212244
zarr_format=zarr_format,
22222245
)
22232246

2247+
_, filtered_paths = _resolve_group_and_filter(group, list(stores.keys()))
2248+
filtered_set = set(filtered_paths)
2249+
22242250
loop = asyncio.get_running_loop()
22252251
max_workers = min(len(stores), 10) if stores else 1
22262252
executor = ThreadPoolExecutor(
@@ -2244,15 +2270,17 @@ def _load_sync():
22442270
)
22452271

22462272
ds = await loop.run_in_executor(executor, _load_sync)
2247-
if group:
2273+
if effective_group:
22482274
group_name = str(NodePath(path_group).relative_to(parent))
22492275
else:
22502276
group_name = str(NodePath(path_group))
22512277
return group_name, ds
22522278

22532279
try:
22542280
tasks = [
2255-
open_one(path_group, store) for path_group, store in stores.items()
2281+
open_one(path_group, store)
2282+
for path_group, store in stores.items()
2283+
if path_group in filtered_set
22562284
]
22572285
results = await asyncio.gather(*tasks)
22582286
finally:

0 commit comments

Comments
 (0)