|
6 | 6 | from array import array |
7 | 7 | from enum import Enum |
8 | 8 | from functools import partial |
| 9 | +from pickle import PickleBuffer |
9 | 10 | from types import ModuleType |
10 | 11 | from typing import Any, Generic, Literal, TypeVar |
11 | 12 |
|
@@ -86,19 +87,27 @@ def buffer_callback(f): |
86 | 87 | return header, frames |
87 | 88 |
|
88 | 89 |
|
89 | | -def pickle_loads(header, frames): |
90 | | - x, buffers = frames[0], frames[1:] |
| 90 | +def pickle_loads( |
| 91 | + header: dict[str, Any], frames: list[bytes | bytearray | memoryview | PickleBuffer] |
| 92 | +) -> Any: |
| 93 | + pik, buffers = frames[0], frames[1:] |
91 | 94 |
|
92 | | - writeable = header.get("writeable") |
93 | | - if not writeable: |
94 | | - writeable = len(buffers) * (None,) |
| 95 | + def ensure_writeable_flag(mv: memoryview, w: bool) -> memoryview: |
| 96 | + if w and mv.readonly: |
| 97 | + # Can't avoid a deep copy |
| 98 | + return memoryview(bytearray(mv)) |
| 99 | + elif not w and not mv.readonly: |
| 100 | + # Zero copy - this is just a flag |
| 101 | + return mv.toreadonly() |
| 102 | + else: |
| 103 | + return mv |
95 | 104 |
|
96 | 105 | buffers = [ |
97 | | - memoryview(bytearray(mv) if w else bytes(mv)) if w == mv.readonly else mv |
98 | | - for w, mv in zip(writeable, map(ensure_memoryview, buffers)) |
| 106 | + ensure_writeable_flag(ensure_memoryview(mv), w) |
| 107 | + for mv, w in zip(buffers, header["writeable"]) |
99 | 108 | ] |
100 | 109 |
|
101 | | - return pickle.loads(x, buffers=buffers) |
| 110 | + return pickle.loads(pik, buffers=buffers) |
102 | 111 |
|
103 | 112 |
|
104 | 113 | def import_allowed_module(name): |
|
0 commit comments