Skip to content

Commit f621c65

Browse files
authored
Ensure inproc properly emulates serialization protocol (#8622)
1 parent 3f13a2d commit f621c65

6 files changed

Lines changed: 42 additions & 16 deletions

File tree

distributed/comm/inproc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector
1515
from distributed.comm.registry import Backend, backends
16-
from distributed.protocol import nested_deserialize
16+
from distributed.protocol.serialize import _nested_deserialize
1717
from distributed.utils import get_ip, is_python_shutting_down
1818

1919
logger = logging.getLogger(__name__)
@@ -218,8 +218,7 @@ async def read(self, deserializers="ignored"):
218218
self._finalizer.detach()
219219
raise CommClosedError()
220220

221-
if self.deserialize:
222-
msg = nested_deserialize(msg)
221+
msg = _nested_deserialize(msg, self.deserialize)
223222
return msg
224223

225224
async def write(self, msg, serializers=None, on_error=None):

distributed/protocol/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from distributed.protocol.core import decompress, dumps, loads, maybe_compress, msgpack
77
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
88
from distributed.protocol.serialize import (
9+
Pickled,
910
Serialize,
1011
Serialized,
12+
ToPickle,
1113
dask_deserialize,
1214
dask_serialize,
1315
deserialize,

distributed/protocol/serialize.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import codecs
44
import importlib
55
import traceback
6+
import warnings
67
from array import array
78
from enum import Enum
89
from functools import partial
@@ -621,6 +622,14 @@ def __ne__(self, other):
621622

622623

623624
def nested_deserialize(x):
625+
warnings.warn(
626+
"nested_deserialize is deprecated and will be removed in a future release.",
627+
DeprecationWarning,
628+
)
629+
return _nested_deserialize(x, emulate_deserialize=True)
630+
631+
632+
def _nested_deserialize(x, emulate_deserialize=True):
624633
"""
625634
Replace all Serialize and Serialized values nested in *x*
626635
with the original values. Returns a copy of *x*.
@@ -637,21 +646,27 @@ def replace_inner(x):
637646
typ = type(v)
638647
if typ is dict or typ is list:
639648
x[k] = replace_inner(v)
640-
elif typ is Serialize:
649+
if emulate_deserialize:
650+
if typ is Serialize:
651+
x[k] = v.data
652+
elif typ is Serialized:
653+
x[k] = deserialize(v.header, v.frames)
654+
if typ is ToPickle:
641655
x[k] = v.data
642-
elif typ is Serialized:
643-
x[k] = deserialize(v.header, v.frames)
644656

645657
elif type(x) is list:
646658
x = list(x)
647659
for k, v in enumerate(x):
648660
typ = type(v)
649661
if typ is dict or typ is list:
650662
x[k] = replace_inner(v)
651-
elif typ is Serialize:
663+
if emulate_deserialize:
664+
if typ is Serialize:
665+
x[k] = v.data
666+
elif typ is Serialized:
667+
x[k] = deserialize(v.header, v.frames)
668+
if typ is ToPickle:
652669
x[k] = v.data
653-
elif typ is Serialized:
654-
x[k] = deserialize(v.header, v.frames)
655670

656671
return x
657672

distributed/protocol/tests/test_serialize.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
from distributed.protocol import (
2121
Serialize,
2222
Serialized,
23+
ToPickle,
2324
dask_serialize,
2425
deserialize,
2526
deserialize_bytes,
2627
dumps,
2728
loads,
28-
nested_deserialize,
2929
register_serialization,
3030
register_serialization_family,
3131
serialize,
@@ -35,6 +35,7 @@
3535
)
3636
from distributed.protocol.serialize import (
3737
_is_msgpack_serializable,
38+
_nested_deserialize,
3839
check_dask_serializable,
3940
)
4041
from distributed.utils import ensure_memoryview, nbytes
@@ -166,12 +167,24 @@ def test_nested_deserialize():
166167
"x": [to_serialize(123), to_serialize(456), 789],
167168
"y": {"a": ["abc", Serialized(*serialize("def"))], "b": b"ghi"},
168169
}
170+
171+
x_orig = copy.deepcopy(x)
172+
assert _nested_deserialize(x, emulate_deserialize=False) == x_orig
173+
174+
assert x == x_orig # x wasn't mutated
175+
x["topickle"] = ToPickle(1)
176+
x["topickle_nested"] = [1, ToPickle(2)]
169177
x_orig = copy.deepcopy(x)
178+
assert (out := _nested_deserialize(x, emulate_deserialize=False)) != x_orig
179+
assert out["topickle"] == 1
180+
assert out["topickle_nested"] == [1, 2]
170181

171-
assert nested_deserialize(x) == {
182+
assert _nested_deserialize(x) == {
172183
"op": "update",
173184
"x": [123, 456, 789],
174185
"y": {"a": ["abc", "def"], "b": b"ghi"},
186+
"topickle": 1,
187+
"topickle_nested": [1, 2],
175188
}
176189
assert x == x_orig # x wasn't mutated
177190

distributed/scheduler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4676,9 +4676,6 @@ async def update_graph(
46764676
annotations: dict | None = None,
46774677
stimulus_id: str | None = None,
46784678
) -> None:
4679-
# FIXME: Apparently empty dicts arrive as a ToPickle object
4680-
if isinstance(annotations, ToPickle):
4681-
annotations = annotations.data # type: ignore[unreachable]
46824679
start = time()
46834680
try:
46844681
try:

distributed/shuffle/tests/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def __init__(self, shuffle: ShuffleRun):
2222

2323
def __getattr__(self, key):
2424
async def _(**kwargs):
25-
from distributed.protocol.serialize import nested_deserialize
25+
from distributed.protocol.serialize import _nested_deserialize
2626

2727
method_name = key.replace("shuffle_", "")
2828
kwargs.pop("shuffle_id", None)
2929
kwargs.pop("run_id", None)
3030
# TODO: This is a bit awkward. At some point the arguments are
3131
# already getting wrapped with a `Serialize`. We only want to unwrap
3232
# here.
33-
kwargs = nested_deserialize(kwargs)
33+
kwargs = _nested_deserialize(kwargs)
3434
meth = getattr(self.shuffle, method_name)
3535
return await meth(**kwargs)
3636

0 commit comments

Comments
 (0)