Skip to content

Commit 9779596

Browse files
authored
Merge branch 'main' into glob-group-filtering
2 parents ada50d9 + 561e5e8 commit 9779596

File tree

3 files changed

+71
-11
lines changed

3 files changed

+71
-11
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ v2026.03.0 (unreleased)
1414
New Features
1515
~~~~~~~~~~~~
1616

17+
- Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit
18+
all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`).
19+
By `Alfonso Ladino <https://github.com/aladinor>`_.
1720

1821
Breaking Changes
1922
~~~~~~~~~~~~~~~~

xarray/core/datatree.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,28 @@ def _coord_variables(self) -> ChainMap[Hashable, Variable]:
588588
*(p._node_coord_variables_with_index for p in self.parents), # type: ignore[arg-type]
589589
)
590590

591+
@property
592+
def _coord_variables_all(self) -> ChainMap[Hashable, Variable]:
593+
return ChainMap(
594+
self._node_coord_variables,
595+
*(p._node_coord_variables for p in self.parents),
596+
)
597+
598+
def _resolve_inherit(
599+
self, inherit: bool | Literal["all_coords", "indexes"]
600+
) -> tuple[Mapping[Hashable, Variable], dict[Hashable, Index]]:
601+
"""Resolve the inherit parameter to (coord_vars, indexes)."""
602+
if inherit is False:
603+
return self._node_coord_variables, dict(self._node_indexes)
604+
if inherit is True or inherit == "indexes":
605+
return self._coord_variables, dict(self._indexes)
606+
if inherit == "all_coords":
607+
return self._coord_variables_all, dict(self._indexes)
608+
raise ValueError(
609+
f"Invalid value for inherit: {inherit!r}. "
610+
"Expected True, False, 'indexes', or 'all'."
611+
)
612+
591613
@property
592614
def _dims(self) -> ChainMap[Hashable, int]:
593615
return ChainMap(self._node_dims, *(p._node_dims for p in self.parents))
@@ -596,8 +618,12 @@ def _dims(self) -> ChainMap[Hashable, int]:
596618
def _indexes(self) -> ChainMap[Hashable, Index]:
597619
return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents))
598620

599-
def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
600-
coord_vars = self._coord_variables if inherit else self._node_coord_variables
621+
def _to_dataset_view(
622+
self,
623+
rebuild_dims: bool,
624+
inherit: bool | Literal["all_coords", "indexes"] = True,
625+
) -> DatasetView:
626+
coord_vars, indexes = self._resolve_inherit(inherit)
601627
variables = dict(self._data_variables)
602628
variables |= coord_vars
603629
if rebuild_dims:
@@ -636,10 +662,10 @@ def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
636662
dims = dict(self._node_dims)
637663
return DatasetView._constructor(
638664
variables=variables,
639-
coord_names=set(self._coord_variables),
665+
coord_names=set(coord_vars),
640666
dims=dims,
641667
attrs=self._attrs,
642-
indexes=dict(self._indexes if inherit else self._node_indexes),
668+
indexes=indexes,
643669
encoding=self._encoding,
644670
close=None,
645671
)
@@ -669,30 +695,39 @@ def dataset(self, data: Dataset | None = None) -> None:
669695
# xarray-contrib/datatree
670696
ds = dataset
671697

672-
def to_dataset(self, inherit: bool = True) -> Dataset:
698+
def to_dataset(
699+
self, inherit: bool | Literal["all_coords", "indexes"] = True
700+
) -> Dataset:
673701
"""
674702
Return the data in this node as a new xarray.Dataset object.
675703
676704
Parameters
677705
----------
678-
inherit : bool, optional
679-
If False, only include coordinates and indexes defined at the level
680-
of this DataTree node, excluding any inherited coordinates and indexes.
706+
inherit : bool or {"all_coords", "indexes"}, default True
707+
Controls which coordinates are inherited from parent nodes.
708+
709+
- True or "indexes": inherit only indexed coordinates (default).
710+
- "all_coords": inherit all coordinates, including non-index coordinates.
711+
- False: only include coordinates defined at this node.
681712
682713
See Also
683714
--------
684715
DataTree.dataset
685716
"""
686-
coord_vars = self._coord_variables if inherit else self._node_coord_variables
717+
coord_vars, indexes = self._resolve_inherit(inherit)
687718
variables = dict(self._data_variables)
688719
variables |= coord_vars
689-
dims = calculate_dimensions(variables) if inherit else dict(self._node_dims)
720+
dims = (
721+
dict(self._node_dims)
722+
if inherit is False
723+
else calculate_dimensions(variables)
724+
)
690725
return Dataset._construct_direct(
691726
variables,
692727
set(coord_vars),
693728
dims,
694729
None if self._attrs is None else dict(self._attrs),
695-
dict(self._indexes if inherit else self._node_indexes),
730+
indexes,
696731
None if self._encoding is None else dict(self._encoding),
697732
None,
698733
)

xarray/tests/test_datatree.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,28 @@ def test_to_dataset_inherited(self) -> None:
243243
assert_identical(tree.to_dataset(inherit=True), base)
244244
assert_identical(subtree.to_dataset(inherit=True), sub_and_base)
245245

246+
def test_to_dataset_inherit_all(self) -> None:
247+
base = xr.Dataset(coords={"a": [1], "b": 2})
248+
sub = xr.Dataset(coords={"c": [3]})
249+
tree = DataTree.from_dict({"/": base, "/sub": sub})
250+
subtree = typing.cast(DataTree, tree["sub"])
251+
252+
expected = xr.Dataset(coords={"a": [1], "b": 2, "c": [3]})
253+
assert_identical(subtree.to_dataset(inherit="all_coords"), expected)
254+
assert_identical(tree.to_dataset(inherit="all_coords"), base)
255+
256+
mid = xr.Dataset(coords={"c": 3.0})
257+
leaf = xr.Dataset(coords={"d": [4]})
258+
deep = DataTree.from_dict({"/": base, "/mid": mid, "/mid/leaf": leaf})
259+
leaf_node = typing.cast(DataTree, deep["/mid/leaf"])
260+
result = leaf_node.to_dataset(inherit="all_coords")
261+
assert set(result.coords) == {"a", "b", "c", "d"}
262+
263+
def test_to_dataset_inherit_invalid(self) -> None:
264+
tree = DataTree()
265+
with pytest.raises(ValueError, match="Invalid value for inherit"):
266+
tree.to_dataset(inherit="invalid") # type: ignore[arg-type]
267+
246268

247269
class TestVariablesChildrenNameCollisions:
248270
def test_parent_already_has_variable_with_childs_name(self) -> None:

0 commit comments

Comments
 (0)