Skip to content

Commit a8732fc

Browse files
authored
Fix encoding of masked dask arrays (#11157)
* Coerce masked dask arrays to filled arrays when constructing variables
* Use `where_method` and let it promote the dtype of all masked arrays
1 parent 3ec5a38 commit a8732fc

5 files changed

Lines changed: 37 additions & 8 deletions

File tree

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Deprecations
2626
Bug Fixes
2727
~~~~~~~~~
2828

29+
- Coerce masked dask arrays to filled arrays when constructing variables, fixing the encoding of masked dask arrays (:issue:`9374`, :pull:`11157`).
30+
By `Julia Signell <https://github.com/jsignell>`_.
2931

3032
Documentation
3133
~~~~~~~~~~~~~

xarray/core/duck_array_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
127127
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
128128
)
129129

130+
getmaskarray = _dask_or_eager_func(
131+
"getmaskarray", eager_module=np.ma, dask_module="dask.array.ma"
132+
)
133+
130134

131135
def sliding_window_view(array, window_shape, axis=None, **kwargs):
132136
# TODO: some libraries (e.g. jax) don't have this, implement an alternative?

xarray/core/variable.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
5151
from xarray.namedarray.parallelcompat import get_chunked_array_type
5252
from xarray.namedarray.pycompat import (
53+
array_type,
5354
async_to_duck_array,
5455
integer_types,
5556
is_0d_dask_array,
@@ -292,13 +293,12 @@ def convert_non_numpy_type(data):
292293
else:
293294
data = pandas_data
294295

295-
if isinstance(data, np.ma.MaskedArray):
296-
mask = np.ma.getmaskarray(data)
297-
if mask.any():
298-
_dtype, fill_value = dtypes.maybe_promote(data.dtype)
299-
data = duck_array_ops.where_method(data, ~mask, fill_value)
300-
else:
301-
data = np.asarray(data)
296+
if isinstance(data, np.ma.MaskedArray) or (
297+
isinstance(data, array_type("dask"))
298+
and isinstance(getattr(data, "_meta", None), np.ma.MaskedArray)
299+
):
300+
mask = duck_array_ops.getmaskarray(data)
301+
data = duck_array_ops.where_method(data, ~mask)
302302

303303
if isinstance(data, np.matrix):
304304
data = np.asarray(data)

xarray/tests/test_backends.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
mock,
8686
network,
8787
parametrize_zarr_format,
88+
raise_if_dask_computes,
8889
requires_cftime,
8990
requires_dask,
9091
requires_fsspec,
@@ -2231,6 +2232,28 @@ def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None:
22312232
else:
22322233
assert len(loaded_ds.xindexes) == 0
22332234

2235+
@requires_dask
2236+
def test_encoding_masked_arrays(self, tmp_path) -> None:
2237+
store_path = tmp_path / "tmp.nc"
2238+
2239+
with raise_if_dask_computes():
2240+
ds = xr.DataArray(
2241+
dask.array.from_array(
2242+
np.ma.masked_array(
2243+
np.array([[np.nan, np.nan], [np.nan, 2]]),
2244+
np.array([[True, True], [True, False]]),
2245+
)
2246+
).astype("float32"),
2247+
dims=("x", "y"),
2248+
).to_dataset(name="mydata")
2249+
2250+
expected = ds.mean("x")
2251+
expected.to_netcdf(
2252+
store_path, encoding=dict(mydata=dict(_FillValue=np.float32(1e20)))
2253+
)
2254+
with open_dataset(store_path, engine=self.engine) as actual:
2255+
assert_identical(expected.compute(), actual.compute())
2256+
22342257

22352258
@requires_netCDF4
22362259
class TestNetCDF4Data(NetCDF4Base):

xarray/tests/test_variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2757,7 +2757,7 @@ def test_masked_array(self):
27572757
expected = np.arange(5)
27582758
actual: Any = as_compatible_data(original)
27592759
assert_array_equal(expected, actual)
2760-
assert np.dtype(int) == actual.dtype
2760+
assert np.dtype(float) == actual.dtype
27612761

27622762
original1: Any = np.ma.MaskedArray(np.arange(5), mask=4 * [False] + [True])
27632763
expected1: Any = np.arange(5.0)

0 commit comments

Comments
 (0)