@@ -588,6 +588,28 @@ def _coord_variables(self) -> ChainMap[Hashable, Variable]:
588588 * (p ._node_coord_variables_with_index for p in self .parents ), # type: ignore[arg-type]
589589 )
590590
591+ @property
592+ def _coord_variables_all (self ) -> ChainMap [Hashable , Variable ]:
593+ return ChainMap (
594+ self ._node_coord_variables ,
595+ * (p ._node_coord_variables for p in self .parents ),
596+ )
597+
598+ def _resolve_inherit (
599+ self , inherit : bool | Literal ["all_coords" , "indexes" ]
600+ ) -> tuple [Mapping [Hashable , Variable ], dict [Hashable , Index ]]:
601+ """Resolve the inherit parameter to (coord_vars, indexes)."""
602+ if inherit is False :
603+ return self ._node_coord_variables , dict (self ._node_indexes )
604+ if inherit is True or inherit == "indexes" :
605+ return self ._coord_variables , dict (self ._indexes )
606+ if inherit == "all_coords" :
607+ return self ._coord_variables_all , dict (self ._indexes )
608+ raise ValueError (
609+ f"Invalid value for inherit: { inherit !r} . "
610+ "Expected True, False, 'indexes', or 'all'."
611+ )
612+
591613 @property
592614 def _dims (self ) -> ChainMap [Hashable , int ]:
593615 return ChainMap (self ._node_dims , * (p ._node_dims for p in self .parents ))
@@ -596,8 +618,12 @@ def _dims(self) -> ChainMap[Hashable, int]:
596618 def _indexes (self ) -> ChainMap [Hashable , Index ]:
597619 return ChainMap (self ._node_indexes , * (p ._node_indexes for p in self .parents ))
598620
599- def _to_dataset_view (self , rebuild_dims : bool , inherit : bool ) -> DatasetView :
600- coord_vars = self ._coord_variables if inherit else self ._node_coord_variables
621+ def _to_dataset_view (
622+ self ,
623+ rebuild_dims : bool ,
624+ inherit : bool | Literal ["all_coords" , "indexes" ] = True ,
625+ ) -> DatasetView :
626+ coord_vars , indexes = self ._resolve_inherit (inherit )
601627 variables = dict (self ._data_variables )
602628 variables |= coord_vars
603629 if rebuild_dims :
@@ -636,10 +662,10 @@ def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
636662 dims = dict (self ._node_dims )
637663 return DatasetView ._constructor (
638664 variables = variables ,
639- coord_names = set (self . _coord_variables ),
665+ coord_names = set (coord_vars ),
640666 dims = dims ,
641667 attrs = self ._attrs ,
642- indexes = dict ( self . _indexes if inherit else self . _node_indexes ) ,
668+ indexes = indexes ,
643669 encoding = self ._encoding ,
644670 close = None ,
645671 )
@@ -669,30 +695,39 @@ def dataset(self, data: Dataset | None = None) -> None:
669695 # xarray-contrib/datatree
670696 ds = dataset
671697
672- def to_dataset (self , inherit : bool = True ) -> Dataset :
698+ def to_dataset (
699+ self , inherit : bool | Literal ["all_coords" , "indexes" ] = True
700+ ) -> Dataset :
673701 """
674702 Return the data in this node as a new xarray.Dataset object.
675703
676704 Parameters
677705 ----------
678- inherit : bool, optional
679- If False, only include coordinates and indexes defined at the level
680- of this DataTree node, excluding any inherited coordinates and indexes.
706+ inherit : bool or {"all_coords", "indexes"}, default True
707+ Controls which coordinates are inherited from parent nodes.
708+
709+ - True or "indexes": inherit only indexed coordinates (default).
710+ - "all_coords": inherit all coordinates, including non-index coordinates.
711+ - False: only include coordinates defined at this node.
681712
682713 See Also
683714 --------
684715 DataTree.dataset
685716 """
686- coord_vars = self ._coord_variables if inherit else self . _node_coord_variables
717+ coord_vars , indexes = self ._resolve_inherit ( inherit )
687718 variables = dict (self ._data_variables )
688719 variables |= coord_vars
689- dims = calculate_dimensions (variables ) if inherit else dict (self ._node_dims )
720+ dims = (
721+ dict (self ._node_dims )
722+ if inherit is False
723+ else calculate_dimensions (variables )
724+ )
690725 return Dataset ._construct_direct (
691726 variables ,
692727 set (coord_vars ),
693728 dims ,
694729 None if self ._attrs is None else dict (self ._attrs ),
695- dict ( self . _indexes if inherit else self . _node_indexes ) ,
730+ indexes ,
696731 None if self ._encoding is None else dict (self ._encoding ),
697732 None ,
698733 )
0 commit comments