Skip to content

Commit 42c479f

Browse files
authored
Scatter by worker instead of worker->nthreads (#8590)
* Scatter round-robin by worker, not by worker->nthreads. * Refactor so that scatter_to_workers no longer requires nthreads.
1 parent 0f2290b commit 42c479f

3 files changed

Lines changed: 9 additions & 14 deletions

File tree

distributed/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,10 +2450,9 @@ async def _scatter(
24502450
nthreads = await self.scheduler.ncores_running(workers=workers)
24512451
if not nthreads: # pragma: no cover
24522452
raise ValueError("No valid workers found")
2453+
workers = list(nthreads.keys())
24532454

2454-
_, who_has, nbytes = await scatter_to_workers(
2455-
nthreads, data2, rpc=self.rpc
2456-
)
2455+
_, who_has, nbytes = await scatter_to_workers(workers, data2, self.rpc)
24572456

24582457
await self.scheduler.update_data(
24592458
who_has=who_has, nbytes=nbytes, client=self.id

distributed/scheduler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6132,16 +6132,15 @@ async def scatter(
61326132
raise TimeoutError("No valid workers found")
61336133
await asyncio.sleep(0.1)
61346134

6135-
nthreads = {ws.address: ws.nthreads for ws in wss}
6136-
61376135
assert isinstance(data, dict)
61386136

6139-
keys, who_has, nbytes = await scatter_to_workers(nthreads, data, rpc=self.rpc)
6137+
workers = list(ws.address for ws in wss)
6138+
keys, who_has, nbytes = await scatter_to_workers(workers, data, rpc=self.rpc)
61406139

61416140
self.update_data(who_has=who_has, nbytes=nbytes, client=client)
61426141

61436142
if broadcast:
6144-
n = len(nthreads) if broadcast is True else broadcast
6143+
n = len(workers) if broadcast is True else broadcast
61456144
await self.replicate(keys=keys, workers=workers, n=n)
61466145

61476146
self.log_event(

distributed/utils_comm.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from itertools import cycle
1010
from typing import Any, TypeVar
1111

12-
from tlz import concat, drop, groupby, merge
12+
from tlz import drop, groupby, merge
1313

1414
import dask.config
1515
from dask.optimization import SubgraphCallable
@@ -151,19 +151,16 @@ def __repr__(self):
151151
_round_robin_counter = [0]
152152

153153

154-
async def scatter_to_workers(nthreads, data, rpc=rpc):
154+
async def scatter_to_workers(workers, data, rpc=rpc):
155155
"""Scatter data directly to workers
156156
157-
This distributes data in a round-robin fashion to a set of workers based on
158-
how many cores they have. nthreads should be a dictionary mapping worker
159-
identities to numbers of cores.
157+
This distributes data in a round-robin fashion to a set of workers.
160158
161159
See scatter for parameter docstring
162160
"""
163-
assert isinstance(nthreads, dict)
164161
assert isinstance(data, dict)
165162

166-
workers = list(concat([w] * nc for w, nc in nthreads.items()))
163+
workers = sorted(workers)
167164
names, data = list(zip(*data.items()))
168165

169166
worker_iter = drop(_round_robin_counter[0] % len(workers), cycle(workers))

0 commit comments

Comments (0)