|
18 | 18 | from packaging.version import parse as parse_version |
19 | 19 | from tornado.ioloop import IOLoop |
20 | 20 |
|
| 21 | +import dask |
21 | 22 | from dask.utils import key_split |
22 | 23 |
|
23 | 24 | from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key |
|
28 | 29 | import numpy as np |
29 | 30 | import pandas as pd |
30 | 31 |
|
31 | | -import dask |
32 | 32 | from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_200 |
33 | 33 | from dask.typing import Key |
34 | 34 |
|
@@ -139,9 +139,10 @@ async def test_minimal_version(c, s, a, b): |
139 | 139 | dtypes={"x": float, "y": float}, |
140 | 140 | freq="10 s", |
141 | 141 | ) |
142 | | - with pytest.raises( |
143 | | - ModuleNotFoundError, match="requires pyarrow" |
144 | | - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): |
| 142 | + with ( |
| 143 | + pytest.raises(ModuleNotFoundError, match="requires pyarrow"), |
| 144 | + dask.config.set({"dataframe.shuffle.method": "p2p"}), |
| 145 | + ): |
145 | 146 | await c.compute(df.shuffle("x")) |
146 | 147 |
|
147 | 148 |
|
@@ -2795,3 +2796,43 @@ def data_gen(): |
2795 | 2796 | "meta", |
2796 | 2797 | ): |
2797 | 2798 | await c.gather(c.compute(ddf.shuffle(on="a"))) |
| 2799 | + |
| 2800 | + |
@gen_cluster(client=True)
async def test_dont_downscale_participating_workers(c, s, a, b):
    """Check ``Scheduler.workers_to_close`` around an in-flight P2P shuffle.

    Expectations pinned by this test:

    * before the shuffle is submitted, both workers are offered for closing;
    * once ``a`` and ``b`` both report an active shuffle run, neither is;
    * with a third worker joined mid-shuffle, exactly one worker is offered;
    * after the result has been gathered and the future dropped, both
      original workers are offered again.
    """
    df = dask.datasets.timeseries(
        start="2000-01-01",
        end="2000-01-10",
        dtypes={"x": float, "y": float},
        freq="10 s",
    )
    # Only graph construction needs the p2p setting; the computation itself
    # is submitted after the config context has been left.
    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
        shuffled = df.shuffle("x")

    # Baseline: nothing has been submitted yet, so both workers are offered.
    assert len(s.workers_to_close(n=2)) == 2
    res = c.compute(shuffled)

    # Awaited purely for synchronisation; the returned shuffle id is not
    # needed by any of the assertions below.
    await wait_until_new_shuffle_is_initialized(s)
    # Poll until each worker reports at least one active shuffle run.
    while not get_active_shuffle_runs(a):
        await asyncio.sleep(0.01)
    while not get_active_shuffle_runs(b):
        await asyncio.sleep(0.01)

    # Both workers now have an active shuffle run: none may be closed.
    assert len(s.workers_to_close(n=2)) == 0

    async with Worker(s.address) as w:
        # The returned future is not retained.
        # NOTE(review): confirm whether this task is required for the
        # assertion below or is only a liveness check for the new worker.
        c.submit(lambda: None, workers=[w.address])

        # Three workers are requested but only one is offered (presumably
        # the freshly added one; the test only checks the count).
        assert len(s.workers_to_close(n=3)) == 1

    # The extra worker is gone again; a and b are still not offered.
    assert len(s.workers_to_close(n=2)) == 0

    await c.gather(res)
    # Drop the local reference to the result future before the final check.
    del res

    assert len(s.workers_to_close(n=2)) == 2
0 commit comments