Skip to content

Commit 6c36d81

Browse files
kkollsgaclaudejsignell
authored
Allow writing StringDType variables to netCDF (#11218)
* Allow writing StringDType variables to netCDF (#11199) Recognize numpy.dtypes.StringDType (kind "T") as a unicode string type in is_unicode_dtype, and convert StringDType arrays to object arrays before passing to netCDF4/h5netcdf backends which don't support StringDType natively. Null values from StringDType(na_object=None) are replaced with empty strings on write. Co-authored-by: Claude <noreply@anthropic.com> * Add StringDType null handling for scipy and tests for review feedback - Handle StringDType null values in encode_string_array (scipy/nc3 path) - Add roundtrip tests for StringDType with na_object=None and na_object="" - Add unit test for encode_string_array with StringDType nulls Co-Authored-By: Claude <noreply@anthropic.com> * Move StringDType handling into shared encoder (#11199) Convert StringDType to fixed-width unicode (U) in EncodedStringCoder.encode() instead of per-backend prepare_variable, fixing Zarr and CFEncodedDataStore. Co-authored-by: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Julia Signell <jsignell@gmail.com>
1 parent d937ce6 commit 6c36d81

File tree

4 files changed

+60
-1
lines changed

4 files changed

+60
-1
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ Deprecations
3939
Bug Fixes
4040
~~~~~~~~~
4141

42+
- Allow writing ``StringDType`` variables to netCDF files (:issue:`11199`).
43+
By `Kristian Kollsgård <https://github.com/kkollsga>`_.
4244
- Fix ``Source`` link in api docs (:pull:`11187`)
4345
By `Ian Hunt-Isaak <https://github.com/ianhi>`_
4446
- Coerce masked dask arrays to filled (:issue:`9374` :pull:`11157`).

xarray/coding/strings.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def check_vlen_dtype(dtype):
4040

4141

4242
def is_unicode_dtype(dtype):
43-
return dtype.kind == "U" or check_vlen_dtype(dtype) is str
43+
return dtype.kind in ("U", "T") or check_vlen_dtype(dtype) is str
4444

4545

4646
def is_bytes_dtype(dtype):
@@ -56,6 +56,14 @@ def __init__(self, allows_unicode=True):
5656
def encode(self, variable: Variable, name=None) -> Variable:
5757
dims, data, attrs, encoding = unpack_for_encoding(variable)
5858

59+
# StringDType: replace nulls and convert to fixed-width unicode (U),
60+
# which all backends support natively (GH11199)
61+
if data.dtype.kind == "T":
62+
data = np.asarray(data, dtype=object)
63+
data[data == None] = "" # noqa: E711
64+
data = np.asarray(data, dtype="U")
65+
variable = Variable(dims, data, attrs, encoding)
66+
5967
contains_unicode = is_unicode_dtype(data.dtype)
6068
encode_as_char = encoding.get("dtype") == "S1"
6169
if encode_as_char:

xarray/tests/test_backends.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,31 @@ def test_roundtrip_string_data(self) -> None:
701701
with self.roundtrip(expected) as actual:
702702
assert_identical(expected, actual)
703703

704+
@pytest.mark.skipif(not HAS_STRING_DTYPE, reason="requires StringDType")
705+
def test_roundtrip_stringdtype_data(self) -> None:
706+
# GH11199
707+
data = np.array(["ab", "cdef"], dtype=np.dtypes.StringDType())
708+
expected = Dataset({"x": ("t", data)})
709+
with self.roundtrip(expected) as actual:
710+
assert_identical(expected, actual)
711+
712+
@pytest.mark.skipif(not HAS_STRING_DTYPE, reason="requires StringDType")
713+
def test_roundtrip_stringdtype_nulls(self) -> None:
714+
# GH11199 — null values in StringDType are written as empty strings
715+
data = np.array(["ab", None], dtype=np.dtypes.StringDType(na_object=None))
716+
ds = Dataset({"x": ("t", data)})
717+
with self.roundtrip(ds) as actual:
718+
expected = Dataset({"x": ("t", np.array(["ab", ""]))})
719+
assert_identical(expected, actual)
720+
721+
@pytest.mark.skipif(not HAS_STRING_DTYPE, reason="requires StringDType")
722+
def test_roundtrip_stringdtype_with_na_object(self) -> None:
723+
# GH11199 — StringDType(na_object="") should roundtrip correctly
724+
data = np.array(["ab", "cdef"], dtype=np.dtypes.StringDType(na_object=""))
725+
expected = Dataset({"x": ("t", data)})
726+
with self.roundtrip(expected) as actual:
727+
assert_identical(expected, actual)
728+
704729
def test_roundtrip_string_encoded_characters(self) -> None:
705730
expected = Dataset({"x": ("t", ["ab", "cdef"])})
706731
expected["x"].encoding["dtype"] = "S1"

xarray/tests/test_coding_strings.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ def test_vlen_dtype() -> None:
3939
assert strings.check_vlen_dtype(np.dtype(object)) is None
4040

4141

42+
@pytest.mark.skipif(
43+
not hasattr(np.dtypes, "StringDType"), reason="requires StringDType"
44+
)
45+
def test_is_unicode_dtype_stringdtype() -> None:
46+
# GH11199
47+
dtype = np.dtypes.StringDType()
48+
assert strings.is_unicode_dtype(dtype)
49+
assert not strings.is_bytes_dtype(dtype)
50+
51+
4252
@pytest.mark.parametrize("numpy_str_type", (np.str_, np.bytes_))
4353
def test_numpy_subclass_handling(numpy_str_type) -> None:
4454
with pytest.raises(TypeError, match="unsupported type for vlen_dtype"):
@@ -93,6 +103,20 @@ def test_EncodedStringCoder_encode() -> None:
93103
assert_identical(coder.encode(raw), expected)
94104

95105

106+
@pytest.mark.skipif(
107+
not hasattr(np.dtypes, "StringDType"), reason="requires StringDType"
108+
)
109+
def test_encoded_string_coder_stringdtype_nulls() -> None:
110+
# GH11199 — EncodedStringCoder normalizes StringDType nulls to empty strings
111+
data = np.array(["ab", None], dtype=np.dtypes.StringDType(na_object=None))
112+
var = Variable("x", data)
113+
coder = strings.EncodedStringCoder(allows_unicode=True)
114+
result = coder.encode(var)
115+
expected = Variable("x", np.array(["ab", ""]))
116+
assert_identical(result, expected)
117+
assert result.dtype.kind == "U"
118+
119+
96120
@pytest.mark.parametrize(
97121
"original",
98122
[

0 commit comments

Comments
 (0)