11import asyncio
2- from concurrent .futures import ThreadPoolExecutor
2+ import atexit
3+ from concurrent .futures import Future , InvalidStateError
34from dataclasses import dataclass
45import inspect
56import itertools as it
67import multiprocessing as mp
78import os
9+ import queue
810import sys
11+ import threading
912from typing import Any , AsyncGenerator , TypeVar , cast
13+ import weakref
1014
1115import nest_asyncio
1216import setproctitle
2428
2529# Special ID to signal shutdown
2630_SHUTDOWN_ID = - 1
31+ _LIVE_PROXIES : weakref .WeakSet [Any ] = weakref .WeakSet ()
32+ _ATEXIT_REGISTERED = False
33+
34+
35+ def _close_all_live_proxies () -> None :
36+ # Best-effort cleanup for callers that forget close_proxy().
37+ for proxy in list (_LIVE_PROXIES ):
38+ try :
39+ close = getattr (proxy , "close" , None )
40+ if callable (close ):
41+ close ()
42+ except BaseException :
43+ pass
44+
45+
46+ def _register_proxy_for_atexit (proxy : "Proxy" ) -> None :
47+ global _ATEXIT_REGISTERED
48+ _LIVE_PROXIES .add (proxy )
49+ if not _ATEXIT_REGISTERED :
50+ atexit .register (_close_all_live_proxies )
51+ _ATEXIT_REGISTERED = True
2752
2853
2954def move_to_child_process (
@@ -89,54 +114,66 @@ def __init__(
89114 args = (obj , self ._requests , self ._responses , log_file , process_name ),
90115 )
91116 self ._process .start ()
92- # dedicated executor for queue.get calls
93- self ._executor = ThreadPoolExecutor ()
94- self ._futures : dict [int , asyncio .Future ] = {}
95- self ._process_future = asyncio .Future ()
96- self ._handle_responses_task = asyncio .create_task (self ._handle_responses ())
97- self ._monitor_task = asyncio .create_task (self ._monitor_process ())
117+ self ._futures : dict [int , Future ] = {}
118+ self ._futures_lock = threading .Lock ()
119+ self ._dead_process_error : RuntimeError | None = None
120+ self ._closing = False
121+ self ._closed = False
98122 self ._next_id = it .count (1 ).__next__
123+ self ._dispatcher = threading .Thread (
124+ target = self ._dispatch_responses , name = "mp-actors-dispatch" , daemon = True
125+ )
126+ self ._dispatcher .start ()
127+ _register_proxy_for_atexit (self )
128+
129+ def _process_error (self ) -> RuntimeError :
130+ exit_code = self ._process .exitcode
131+ name = f" '{ self ._process_name } '" if self ._process_name else ""
132+ if exit_code is None :
133+ return RuntimeError (f"Child process{ name } died unexpectedly" )
134+ if exit_code < 0 :
135+ return RuntimeError (
136+ f"Child process{ name } was killed by signal { - exit_code } "
137+ )
138+ return RuntimeError (f"Child process{ name } exited with code { exit_code } " )
139+
140+ def _fail_pending (self , error : Exception ) -> None :
141+ with self ._futures_lock :
142+ pending = list (self ._futures .values ())
143+ self ._futures .clear ()
144+ for future in pending :
145+ if not future .done ():
146+ try :
147+ future .set_exception (error )
148+ except InvalidStateError :
149+ pass
99150
100- async def _handle_responses (self ) -> None :
101- loop = asyncio .get_event_loop ()
151+ def _dispatch_responses (self ) -> None :
102152 while True :
103- response : Response = await loop .run_in_executor (
104- self ._executor , self ._responses .get
105- )
106- # check for shutdown signal
153+ try :
154+ response : Response = self ._responses .get (timeout = 0.1 )
155+ except queue .Empty :
156+ if self ._closing :
157+ break
158+ if self ._dead_process_error is None and not self ._process .is_alive ():
159+ self ._dead_process_error = self ._process_error ()
160+ self ._fail_pending (self ._dead_process_error )
161+ continue
162+ except Exception :
163+ break
107164 if response .id == _SHUTDOWN_ID :
108165 break
109- # normal processing
110- future = self ._futures .pop (response .id , None )
166+ with self . _futures_lock :
167+ future = self ._futures .pop (response .id , None )
111168 if future is None :
112169 continue
113- if response .exception :
114- future .set_exception (response .exception )
115- else :
116- future .set_result (response .result )
117-
118- async def _monitor_process (self ) -> None :
119- """Monitor the child process and set exception if it dies unexpectedly."""
120- loop = asyncio .get_event_loop ()
121- while not self ._process_future .done ():
122- is_alive = await loop .run_in_executor (None , self ._process .is_alive )
123- if not is_alive :
124- if not self ._process_future .done ():
125- exit_code = self ._process .exitcode
126- name = f" '{ self ._process_name } '" if self ._process_name else ""
127- if exit_code is None :
128- exc = RuntimeError (f"Child process{ name } died unexpectedly" )
129- elif exit_code < 0 :
130- exc = RuntimeError (
131- f"Child process{ name } was killed by signal { - exit_code } "
132- )
133- else :
134- exc = RuntimeError (
135- f"Child process{ name } exited with code { exit_code } "
136- )
137- self ._process_future .set_exception (exc )
138- break
139- await asyncio .sleep (0.1 )
170+ try :
171+ if response .exception :
172+ future .set_exception (response .exception )
173+ else :
174+ future .set_result (response .result )
175+ except InvalidStateError :
176+ pass
140177
141178 @streamline_tracebacks ()
142179 def __getattr__ (self , name : str ) -> Any :
@@ -146,26 +183,33 @@ def __getattr__(self, name: str) -> Any:
146183 f"{ type (self ._obj ).__name__ } has no attribute '{ name } '"
147184 )
148185
149- async def get_response (
186+ def response_future (
150187 args : tuple [Any , ...],
151188 kwargs : dict [str , Any ],
152189 id : int | None = None ,
153190 send_value : Any | None = None ,
154- ) -> Any :
191+ ) -> Future :
155192 request = Request (
156193 id = id if id is not None else self ._next_id (),
157194 method_name = name ,
158195 args = args ,
159196 kwargs = kwargs ,
160197 send_value = send_value ,
161198 )
162- self ._futures [request .id ] = asyncio .Future ()
163- self ._requests .put_nowait (request )
164- done , _ = await asyncio .wait (
165- [self ._futures [request .id ], self ._process_future ],
166- return_when = asyncio .FIRST_COMPLETED ,
167- )
168- return done .pop ().result ()
199+ future : Future = Future ()
200+ with self ._futures_lock :
201+ if self ._dead_process_error :
202+ raise self ._dead_process_error
203+ if self ._closing :
204+ raise RuntimeError ("Proxy is closing" )
205+ self ._futures [request .id ] = future
206+ try :
207+ self ._requests .put_nowait (request )
208+ except BaseException :
209+ with self ._futures_lock :
210+ self ._futures .pop (request .id , None )
211+ raise
212+ return future
169213
170214 # Check if it's a method or property
171215 attr = getattr (self ._obj , name )
@@ -179,8 +223,8 @@ async def async_gen_wrapper(
179223 id = self ._next_id ()
180224 send_value = None
181225 while True :
182- send_value = yield await get_response (
183- args , kwargs , id , send_value
226+ send_value = yield await asyncio . wrap_future (
227+ response_future ( args , kwargs , id , send_value )
184228 )
185229 args , kwargs = (), {}
186230 except StopAsyncIteration :
@@ -191,62 +235,41 @@ async def async_gen_wrapper(
191235 # Return an async wrapper function
192236 @streamline_tracebacks ()
193237 async def async_method_wrapper (* args : Any , ** kwargs : Any ) -> Any :
194- return await get_response ( args , kwargs )
238+ return await asyncio . wrap_future ( response_future ( args , kwargs ) )
195239
196240 return async_method_wrapper
197241 elif callable (attr ):
198242 # Return a regular function wrapper
199243 @streamline_tracebacks ()
200244 def method_wrapper (* args : Any , ** kwargs : Any ) -> Any :
201- return asyncio . run ( get_response ( args , kwargs ))
245+ return response_future ( args , kwargs ). result ( )
202246
203247 return method_wrapper
204248 else :
205249 # For non-callable attributes, get them directly
206- return asyncio . run ( get_response ( tuple (), dict ()))
250+ return response_future ( tuple (), dict ()). result ( )
207251
208252 def close (self ):
209- # Cancel monitoring to avoid false alarms during shutdown
210- if not self ._process_future .done ():
211- self ._process_future .cancel ()
212- if hasattr (self , "_monitor_task" ):
213- self ._monitor_task .cancel ()
214-
215- # signal the response loop to exit
216- self ._responses .put_nowait (Response (_SHUTDOWN_ID , None , None ))
217- # wait for the handler to finish
218- if hasattr (self , "_handle_responses_task" ):
219- # give it a moment to break
220- try :
221- asyncio .get_event_loop ().run_until_complete (self ._handle_responses_task )
222- except Exception :
223- pass
253+ if self ._closed :
254+ return
255+ self ._closed = True
256+ _LIVE_PROXIES .discard (self )
257+ self ._closing = True
258+ self ._fail_pending (RuntimeError ("Proxy is closing" ))
224259
225260 # terminate child process and force kill if needed
226- if hasattr (self , "_process" ):
227- self ._process .terminate ()
228- try :
229- self ._process .join (timeout = 1 )
230- except Exception :
231- pass
232- if self ._process .is_alive ():
233- # Python 3.7+: force kill
234- try :
235- self ._process .kill ()
236- except AttributeError :
237- # fallback: os.kill
238- if self ._process .pid :
239- os .kill (self ._process .pid , 9 )
240- self ._process .join ()
241-
242- # shutdown executor cleanly
243- self ._executor .shutdown (wait = True )
261+ self ._process .terminate ()
262+ self ._process .join (timeout = 1 )
263+ if self ._process .is_alive ():
264+ self ._process .kill ()
265+ self ._process .join (timeout = 1 )
244266
245267 # close and cancel queue feeder threads
246268 self ._responses .close ()
247269 self ._responses .cancel_join_thread ()
248270 self ._requests .close ()
249271 self ._requests .cancel_join_thread ()
272+ self ._dispatcher .join (timeout = 1 )
250273
251274
252275def _target (
@@ -272,7 +295,9 @@ async def _handle_requests(
272295 request : Request = await asyncio .get_event_loop ().run_in_executor (
273296 None , requests .get
274297 )
275- asyncio .create_task (_handle_request (obj , request , responses , generators ))
298+ asyncio .create_task (
299+ _handle_request (obj , request , responses , generators )
300+ ).add_done_callback (lambda t : None if t .cancelled () else t .exception ())
276301
277302
278303async def _handle_request (
0 commit comments