Skip to content

Commit 73c8cdd

Browse files
committed
fix: harden mp_actors proxy shutdown and add regression tests
Refactor proxy response handling to avoid re-entrant loop deadlocks and make shutdown resilient to cancelled-future races, while keeping child-process failures fail-fast. Add readable unit tests for proxy call paths, process-exit behavior, async-generator cancellation recovery, and close semantics. Made-with: Cursor
1 parent 32ff748 commit 73c8cdd

2 files changed

Lines changed: 256 additions & 90 deletions

File tree

src/mp_actors/move.py

Lines changed: 115 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import asyncio
2-
from concurrent.futures import ThreadPoolExecutor
2+
import atexit
3+
from concurrent.futures import Future, InvalidStateError
34
from dataclasses import dataclass
45
import inspect
56
import itertools as it
67
import multiprocessing as mp
78
import os
9+
import queue
810
import sys
11+
import threading
912
from typing import Any, AsyncGenerator, TypeVar, cast
13+
import weakref
1014

1115
import nest_asyncio
1216
import setproctitle
@@ -24,6 +28,27 @@
2428

2529
# Special ID to signal shutdown
2630
_SHUTDOWN_ID = -1
31+
_LIVE_PROXIES: weakref.WeakSet[Any] = weakref.WeakSet()
32+
_ATEXIT_REGISTERED = False
33+
34+
35+
def _close_all_live_proxies() -> None:
36+
# Best-effort cleanup for callers that forget close_proxy().
37+
for proxy in list(_LIVE_PROXIES):
38+
try:
39+
close = getattr(proxy, "close", None)
40+
if callable(close):
41+
close()
42+
except BaseException:
43+
pass
44+
45+
46+
def _register_proxy_for_atexit(proxy: "Proxy") -> None:
47+
global _ATEXIT_REGISTERED
48+
_LIVE_PROXIES.add(proxy)
49+
if not _ATEXIT_REGISTERED:
50+
atexit.register(_close_all_live_proxies)
51+
_ATEXIT_REGISTERED = True
2752

2853

2954
def move_to_child_process(
@@ -89,54 +114,66 @@ def __init__(
89114
args=(obj, self._requests, self._responses, log_file, process_name),
90115
)
91116
self._process.start()
92-
# dedicated executor for queue.get calls
93-
self._executor = ThreadPoolExecutor()
94-
self._futures: dict[int, asyncio.Future] = {}
95-
self._process_future = asyncio.Future()
96-
self._handle_responses_task = asyncio.create_task(self._handle_responses())
97-
self._monitor_task = asyncio.create_task(self._monitor_process())
117+
self._futures: dict[int, Future] = {}
118+
self._futures_lock = threading.Lock()
119+
self._dead_process_error: RuntimeError | None = None
120+
self._closing = False
121+
self._closed = False
98122
self._next_id = it.count(1).__next__
123+
self._dispatcher = threading.Thread(
124+
target=self._dispatch_responses, name="mp-actors-dispatch", daemon=True
125+
)
126+
self._dispatcher.start()
127+
_register_proxy_for_atexit(self)
128+
129+
def _process_error(self) -> RuntimeError:
130+
exit_code = self._process.exitcode
131+
name = f" '{self._process_name}'" if self._process_name else ""
132+
if exit_code is None:
133+
return RuntimeError(f"Child process{name} died unexpectedly")
134+
if exit_code < 0:
135+
return RuntimeError(
136+
f"Child process{name} was killed by signal {-exit_code}"
137+
)
138+
return RuntimeError(f"Child process{name} exited with code {exit_code}")
139+
140+
def _fail_pending(self, error: Exception) -> None:
141+
with self._futures_lock:
142+
pending = list(self._futures.values())
143+
self._futures.clear()
144+
for future in pending:
145+
if not future.done():
146+
try:
147+
future.set_exception(error)
148+
except InvalidStateError:
149+
pass
99150

100-
async def _handle_responses(self) -> None:
101-
loop = asyncio.get_event_loop()
151+
def _dispatch_responses(self) -> None:
102152
while True:
103-
response: Response = await loop.run_in_executor(
104-
self._executor, self._responses.get
105-
)
106-
# check for shutdown signal
153+
try:
154+
response: Response = self._responses.get(timeout=0.1)
155+
except queue.Empty:
156+
if self._closing:
157+
break
158+
if self._dead_process_error is None and not self._process.is_alive():
159+
self._dead_process_error = self._process_error()
160+
self._fail_pending(self._dead_process_error)
161+
continue
162+
except Exception:
163+
break
107164
if response.id == _SHUTDOWN_ID:
108165
break
109-
# normal processing
110-
future = self._futures.pop(response.id, None)
166+
with self._futures_lock:
167+
future = self._futures.pop(response.id, None)
111168
if future is None:
112169
continue
113-
if response.exception:
114-
future.set_exception(response.exception)
115-
else:
116-
future.set_result(response.result)
117-
118-
async def _monitor_process(self) -> None:
119-
"""Monitor the child process and set exception if it dies unexpectedly."""
120-
loop = asyncio.get_event_loop()
121-
while not self._process_future.done():
122-
is_alive = await loop.run_in_executor(None, self._process.is_alive)
123-
if not is_alive:
124-
if not self._process_future.done():
125-
exit_code = self._process.exitcode
126-
name = f" '{self._process_name}'" if self._process_name else ""
127-
if exit_code is None:
128-
exc = RuntimeError(f"Child process{name} died unexpectedly")
129-
elif exit_code < 0:
130-
exc = RuntimeError(
131-
f"Child process{name} was killed by signal {-exit_code}"
132-
)
133-
else:
134-
exc = RuntimeError(
135-
f"Child process{name} exited with code {exit_code}"
136-
)
137-
self._process_future.set_exception(exc)
138-
break
139-
await asyncio.sleep(0.1)
170+
try:
171+
if response.exception:
172+
future.set_exception(response.exception)
173+
else:
174+
future.set_result(response.result)
175+
except InvalidStateError:
176+
pass
140177

141178
@streamline_tracebacks()
142179
def __getattr__(self, name: str) -> Any:
@@ -146,26 +183,33 @@ def __getattr__(self, name: str) -> Any:
146183
f"{type(self._obj).__name__} has no attribute '{name}'"
147184
)
148185

149-
async def get_response(
186+
def response_future(
150187
args: tuple[Any, ...],
151188
kwargs: dict[str, Any],
152189
id: int | None = None,
153190
send_value: Any | None = None,
154-
) -> Any:
191+
) -> Future:
155192
request = Request(
156193
id=id if id is not None else self._next_id(),
157194
method_name=name,
158195
args=args,
159196
kwargs=kwargs,
160197
send_value=send_value,
161198
)
162-
self._futures[request.id] = asyncio.Future()
163-
self._requests.put_nowait(request)
164-
done, _ = await asyncio.wait(
165-
[self._futures[request.id], self._process_future],
166-
return_when=asyncio.FIRST_COMPLETED,
167-
)
168-
return done.pop().result()
199+
future: Future = Future()
200+
with self._futures_lock:
201+
if self._dead_process_error:
202+
raise self._dead_process_error
203+
if self._closing:
204+
raise RuntimeError("Proxy is closing")
205+
self._futures[request.id] = future
206+
try:
207+
self._requests.put_nowait(request)
208+
except BaseException:
209+
with self._futures_lock:
210+
self._futures.pop(request.id, None)
211+
raise
212+
return future
169213

170214
# Check if it's a method or property
171215
attr = getattr(self._obj, name)
@@ -179,8 +223,8 @@ async def async_gen_wrapper(
179223
id = self._next_id()
180224
send_value = None
181225
while True:
182-
send_value = yield await get_response(
183-
args, kwargs, id, send_value
226+
send_value = yield await asyncio.wrap_future(
227+
response_future(args, kwargs, id, send_value)
184228
)
185229
args, kwargs = (), {}
186230
except StopAsyncIteration:
@@ -191,62 +235,41 @@ async def async_gen_wrapper(
191235
# Return an async wrapper function
192236
@streamline_tracebacks()
193237
async def async_method_wrapper(*args: Any, **kwargs: Any) -> Any:
194-
return await get_response(args, kwargs)
238+
return await asyncio.wrap_future(response_future(args, kwargs))
195239

196240
return async_method_wrapper
197241
elif callable(attr):
198242
# Return a regular function wrapper
199243
@streamline_tracebacks()
200244
def method_wrapper(*args: Any, **kwargs: Any) -> Any:
201-
return asyncio.run(get_response(args, kwargs))
245+
return response_future(args, kwargs).result()
202246

203247
return method_wrapper
204248
else:
205249
# For non-callable attributes, get them directly
206-
return asyncio.run(get_response(tuple(), dict()))
250+
return response_future(tuple(), dict()).result()
207251

208252
def close(self):
209-
# Cancel monitoring to avoid false alarms during shutdown
210-
if not self._process_future.done():
211-
self._process_future.cancel()
212-
if hasattr(self, "_monitor_task"):
213-
self._monitor_task.cancel()
214-
215-
# signal the response loop to exit
216-
self._responses.put_nowait(Response(_SHUTDOWN_ID, None, None))
217-
# wait for the handler to finish
218-
if hasattr(self, "_handle_responses_task"):
219-
# give it a moment to break
220-
try:
221-
asyncio.get_event_loop().run_until_complete(self._handle_responses_task)
222-
except Exception:
223-
pass
253+
if self._closed:
254+
return
255+
self._closed = True
256+
_LIVE_PROXIES.discard(self)
257+
self._closing = True
258+
self._fail_pending(RuntimeError("Proxy is closing"))
224259

225260
# terminate child process and force kill if needed
226-
if hasattr(self, "_process"):
227-
self._process.terminate()
228-
try:
229-
self._process.join(timeout=1)
230-
except Exception:
231-
pass
232-
if self._process.is_alive():
233-
# Python 3.7+: force kill
234-
try:
235-
self._process.kill()
236-
except AttributeError:
237-
# fallback: os.kill
238-
if self._process.pid:
239-
os.kill(self._process.pid, 9)
240-
self._process.join()
241-
242-
# shutdown executor cleanly
243-
self._executor.shutdown(wait=True)
261+
self._process.terminate()
262+
self._process.join(timeout=1)
263+
if self._process.is_alive():
264+
self._process.kill()
265+
self._process.join(timeout=1)
244266

245267
# close and cancel queue feeder threads
246268
self._responses.close()
247269
self._responses.cancel_join_thread()
248270
self._requests.close()
249271
self._requests.cancel_join_thread()
272+
self._dispatcher.join(timeout=1)
250273

251274

252275
def _target(
@@ -272,7 +295,9 @@ async def _handle_requests(
272295
request: Request = await asyncio.get_event_loop().run_in_executor(
273296
None, requests.get
274297
)
275-
asyncio.create_task(_handle_request(obj, request, responses, generators))
298+
asyncio.create_task(
299+
_handle_request(obj, request, responses, generators)
300+
).add_done_callback(lambda t: None if t.cancelled() else t.exception())
276301

277302

278303
async def _handle_request(

0 commit comments

Comments
 (0)