 from dask import delayed
 from dask.utils import parse_bytes

-from distributed import Client, Nanny, profile, wait
+from distributed import Client, KilledWorker, Nanny, get_worker, profile, wait
 from distributed.comm import CommClosedError
 from distributed.compatibility import MACOS
+from distributed.core import Status
 from distributed.metrics import time
 from distributed.utils import CancelledError, sync
 from distributed.utils_test import (
@@ -450,10 +451,10 @@ async def test_restart_timeout_on_long_running_task(c, s, a):


 @pytest.mark.slow
-@gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "500ms"})
+@gen_cluster(client=True, config={"distributed.scheduler.worker-ttl": "500ms"})
 async def test_worker_time_to_live(c, s, a, b):
-    from distributed.scheduler import heartbeat_interval
-
+    # Note that this value is ignored, because it is less than 10x heartbeat_interval
+    assert s.worker_ttl == 0.5
     assert set(s.workers) == {a.address, b.address}

     a.periodic_callbacks["heartbeat"].stop()
@@ -465,10 +466,84 @@ async def test_worker_time_to_live(c, s, a, b):

     # Worker removal is triggered after 10 * heartbeat
     # This is 10 * 0.5s at the moment of writing.
-    interval = 10 * heartbeat_interval(len(s.workers))
     # Currently observing an extra 0.3~0.6s on top of the interval.
     # Adding some padding to prevent flakiness.
-    assert time() - start < interval + 2.0
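+    # 10 * 0.5s = 5s, plus ~2s of padding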
+    assert time() - start < 7
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("block_evloop", [False, True])
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    nthreads=[("", 1)],
+    scheduler_kwargs={"worker_ttl": "500ms", "allowed_failures": 0},
+)
+async def test_worker_ttl_restarts_worker(c, s, a, block_evloop):
+    """If the event loop of a worker becomes completely unresponsive, the scheduler will
+    restart it through the nanny.
+    """
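+    # Remember the initial WorkerState, so that the replacement worker can be detected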
+    ws = s.workers[a.worker_address]
+
+    async def f():
+        w = get_worker()
+        w.periodic_callbacks["heartbeat"].stop()
+        if block_evloop:
+            sleep(9999)  # Block event loop indefinitely
+        else:
+            await asyncio.sleep(9999)
+
+    fut = c.submit(f, key="x")
+
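+    # Wait for worker-ttl to expire and for the nanny to bring up a fresh worker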
+    while not s.workers or (
+        (new_ws := next(iter(s.workers.values()))) is ws
+        or new_ws.status != Status.running
+    ):
+        await asyncio.sleep(0.01)
+
+    if block_evloop:
+        # The nanny killed the worker with SIGKILL.
+        # The restart has increased the suspicious count.
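+        # With allowed_failures=0, a single kill is enough to raise KilledWorker.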
+        with pytest.raises(KilledWorker):
+            await fut
+        assert s.tasks["x"].state == "erred"
+        assert s.tasks["x"].suspicious == 1
+    else:
+        # The nanny sent the WorkerProcess a {op: stop} message through IPC, which in
+        # turn successfully invoked Worker.close(nanny=False).
+        # This behaviour makes sense, as the worker-ttl timeout was most likely caused
+        # by a failure in networking rather than by a hung process.
+        assert s.tasks["x"].state == "processing"
+        assert s.tasks["x"].suspicious == 0
+
+
+@pytest.mark.slow
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    nthreads=[("", 2)],
+    scheduler_kwargs={"allowed_failures": 0},
+)
+async def test_restart_hung_worker(c, s, a):
+    """Test that restart_workers() can restart a worker whose event loop has become
+    completely unresponsive.
+    """
+    ws = s.workers[a.worker_address]
+
+    async def f():
+        w = get_worker()
+        w.periodic_callbacks["heartbeat"].stop()
+        sleep(9999)  # Block event loop indefinitely
+
+    fut = c.submit(f)
+    # Wait for the worker to hang
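+    # (a trivial task failing to complete within 200ms proves the event loop is stuck)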
+    with pytest.raises(asyncio.TimeoutError):
+        while True:
+            await wait(c.submit(inc, 1, pure=False), timeout=0.2)
+
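+    # restart_workers() goes through the nanny, which replaces the hung process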
+    await c.restart_workers([a.worker_address])
+    assert len(s.workers) == 1
+    assert next(iter(s.workers.values())) is not ws


 @gen_cluster(client=True, nthreads=[("", 1)])