Skip to content

Commit e0a7525

Browse files
authored
ensure workers are not downscaled when participating in p2p (#8610)
1 parent 66ced13 commit e0a7525

4 files changed

Lines changed: 80 additions & 4 deletions

File tree

distributed/diagnostics/plugin.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# circular imports
2424
from distributed.scheduler import Scheduler
2525
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
26+
from distributed.scheduler import WorkerState
2627
from distributed.worker import Worker
2728
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState
2829

@@ -205,6 +206,29 @@ def add_client(self, scheduler: Scheduler, client: str) -> None:
205206
def remove_client(self, scheduler: Scheduler, client: str) -> None:
206207
"""Run when a client disconnects"""
207208

209+
def valid_workers_downscaling(
    self, scheduler: Scheduler, workers: list[WorkerState]
) -> list[WorkerState]:
    """Determine which workers can be removed from the cluster.

    This method is called when the scheduler is about to downscale the
    cluster by removing workers. The method should return the subset of
    worker states that can safely be removed from the cluster.

    Parameters
    ----------
    scheduler : Scheduler
        The scheduler instance performing the downscaling.
    workers : list
        The list of worker states that are candidates for removal.

    Returns
    -------
    list
        The list of worker states that can be removed from the cluster.
    """
    # Default implementation imposes no restriction: every candidate
    # worker remains eligible for removal. Plugins override this to veto
    # workers (e.g. those participating in an active P2P shuffle).
    return workers
231+
208232
def log_event(self, topic: str, msg: Any) -> None:
209233
"""Run when an event is logged"""
210234

distributed/scheduler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7153,6 +7153,9 @@ def workers_to_close(
71537153
# running on, as it would cause them to restart from scratch
71547154
# somewhere else.
71557155
valid_workers = [ws for ws in self.workers.values() if not ws.long_running]
7156+
for plugin in list(self.plugins.values()):
7157+
valid_workers = plugin.valid_workers_downscaling(self, valid_workers)
7158+
71567159
groups = groupby(key, valid_workers)
71577160

71587161
limit_bytes = {k: sum(ws.memory_limit for ws in v) for k, v in groups.items()}

distributed/shuffle/_scheduler_plugin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,14 @@ def transition(
410410
if not archived:
411411
del self._archived_by_stimulus[shuffle._archived_by]
412412

413+
def valid_workers_downscaling(
    self, scheduler: Scheduler, workers: list[WorkerState]
) -> list[WorkerState]:
    """Veto downscaling of workers that hold state for an active shuffle.

    Collects the addresses of every worker participating in any currently
    active P2P shuffle and filters those out of the candidate list, so the
    scheduler never removes a worker whose shuffle partitions would be lost.
    """
    participating: set = set()
    for shuffle_state in self.active_shuffles.values():
        participating |= set(shuffle_state.participating_workers)
    return [ws for ws in workers if ws.address not in participating]
420+
413421
def _fail_on_workers(self, shuffle: SchedulerShuffleState, message: str) -> None:
414422
worker_msgs = {
415423
worker: [

distributed/shuffle/tests/test_shuffle.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from packaging.version import parse as parse_version
1919
from tornado.ioloop import IOLoop
2020

21+
import dask
2122
from dask.utils import key_split
2223

2324
from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key
@@ -28,7 +29,6 @@
2829
import numpy as np
2930
import pandas as pd
3031

31-
import dask
3232
from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_200
3333
from dask.typing import Key
3434

@@ -139,9 +139,10 @@ async def test_minimal_version(c, s, a, b):
139139
dtypes={"x": float, "y": float},
140140
freq="10 s",
141141
)
142-
with pytest.raises(
143-
ModuleNotFoundError, match="requires pyarrow"
144-
), dask.config.set({"dataframe.shuffle.method": "p2p"}):
142+
with (
143+
pytest.raises(ModuleNotFoundError, match="requires pyarrow"),
144+
dask.config.set({"dataframe.shuffle.method": "p2p"}),
145+
):
145146
await c.compute(df.shuffle("x"))
146147

147148

@@ -2795,3 +2796,43 @@ def data_gen():
27952796
"meta",
27962797
):
27972798
await c.gather(c.compute(ddf.shuffle(on="a")))
2799+
2800+
2801+
@gen_cluster(client=True)
async def test_dont_downscale_participating_workers(c, s, a, b):
    # Regression test for the scheduler refusing to offer workers for
    # downscaling while they participate in an active P2P shuffle
    # (the ShuffleSchedulerPlugin's valid_workers_downscaling hook).
    df = dask.datasets.timeseries(
        start="2000-01-01",
        end="2000-01-10",
        dtypes={"x": float, "y": float},
        freq="10 s",
    )
    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
        shuffled = df.shuffle("x")

    # Before the shuffle starts, both workers are valid downscaling targets.
    workers_to_close = s.workers_to_close(n=2)
    assert len(workers_to_close) == 2
    res = c.compute(shuffled)

    # Wait until the shuffle has been initialized on the scheduler and a
    # shuffle run is active on both workers.
    shuffle_id = await wait_until_new_shuffle_is_initialized(s)
    while not get_active_shuffle_runs(a):
        await asyncio.sleep(0.01)
    while not get_active_shuffle_runs(b):
        await asyncio.sleep(0.01)

    # Both workers now participate in the shuffle, so neither may be closed.
    workers_to_close = s.workers_to_close(n=2)
    assert len(workers_to_close) == 0

    # Add a third worker that only runs an unrelated task: it is not a
    # shuffle participant, so it is the only one eligible for removal.
    async with Worker(s.address) as w:
        c.submit(lambda: None, workers=[w.address])

        workers_to_close = s.workers_to_close(n=3)
        assert len(workers_to_close) == 1

        workers_to_close = s.workers_to_close(n=2)
        assert len(workers_to_close) == 0

        # Once the shuffle output is gathered and released, the shuffle is
        # no longer active and the original workers become closable again.
        await c.gather(res)
        del res

        workers_to_close = s.workers_to_close(n=2)
        assert len(workers_to_close) == 2

0 commit comments

Comments
 (0)