Skip to content

Commit 9e7b276

Browse files
authored
Don't deep-copy read-only buffers on unpickle (#8609)
1 parent 5647d06 commit 9e7b276

2 files changed

Lines changed: 35 additions & 8 deletions

File tree

distributed/protocol/serialize.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from array import array
77
from enum import Enum
88
from functools import partial
9+
from pickle import PickleBuffer
910
from types import ModuleType
1011
from typing import Any, Generic, Literal, TypeVar
1112

@@ -86,19 +87,27 @@ def buffer_callback(f):
8687
return header, frames
8788

8889

89-
def pickle_loads(header, frames):
90-
x, buffers = frames[0], frames[1:]
90+
def pickle_loads(
91+
header: dict[str, Any], frames: list[bytes | bytearray | memoryview | PickleBuffer]
92+
) -> Any:
93+
pik, buffers = frames[0], frames[1:]
9194

92-
writeable = header.get("writeable")
93-
if not writeable:
94-
writeable = len(buffers) * (None,)
95+
def ensure_writeable_flag(mv: memoryview, w: bool) -> memoryview:
96+
if w and mv.readonly:
97+
# Can't avoid a deep copy
98+
return memoryview(bytearray(mv))
99+
elif not w and not mv.readonly:
100+
# Zero copy - this is just a flag
101+
return mv.toreadonly()
102+
else:
103+
return mv
95104

96105
buffers = [
97-
memoryview(bytearray(mv) if w else bytes(mv)) if w == mv.readonly else mv
98-
for w, mv in zip(writeable, map(ensure_memoryview, buffers))
106+
ensure_writeable_flag(ensure_memoryview(mv), w)
107+
for mv, w in zip(buffers, header["writeable"])
99108
]
100109

101-
return pickle.loads(x, buffers=buffers)
110+
return pickle.loads(pik, buffers=buffers)
102111

103112

104113
def import_allowed_module(name):

distributed/protocol/tests/test_pickle.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,21 @@ def func(df):
297297
out, _ = proc.communicate(timeout=60)
298298

299299
assert "script successful" in out.decode("utf-8")
300+
301+
302+
@pytest.mark.parametrize("serializer", ["dask", "pickle"])
303+
def test_pickle_zero_copy_read_only_flag(serializer):
304+
np = pytest.importorskip("numpy")
305+
a = np.arange(10)
306+
a.flags.writeable = False
307+
header, frames = serialize(a, serializers=[serializer])
308+
frames = [bytearray(f) for f in frames] # Simulate network transfer
309+
b = deserialize(header, frames)
310+
c = deserialize(header, frames)
311+
assert not b.flags.writeable
312+
assert not c.flags.writeable
313+
ptr_a = a.__array_interface__["data"][0]
314+
ptr_b = b.__array_interface__["data"][0]
315+
ptr_c = c.__array_interface__["data"][0]
316+
assert ptr_b != ptr_a
317+
assert ptr_b == ptr_c

0 commit comments

Comments
 (0)