|
9 | 9 | from itertools import cycle |
10 | 10 | from typing import Any, TypeVar |
11 | 11 |
|
12 | | -from tlz import concat, drop, groupby, merge |
| 12 | +from tlz import drop, groupby, merge |
13 | 13 |
|
14 | 14 | import dask.config |
15 | 15 | from dask.optimization import SubgraphCallable |
@@ -151,19 +151,16 @@ def __repr__(self): |
151 | 151 | _round_robin_counter = [0] |
152 | 152 |
|
153 | 153 |
|
154 | | -async def scatter_to_workers(nthreads, data, rpc=rpc): |
| 154 | +async def scatter_to_workers(workers, data, rpc=rpc): |
155 | 155 | """Scatter data directly to workers |
156 | 156 |
|
157 | | - This distributes data in a round-robin fashion to a set of workers based on |
158 | | - how many cores they have. nthreads should be a dictionary mapping worker |
159 | | - identities to numbers of cores. |
| 157 | + This distributes data in a round-robin fashion to a set of workers. |
160 | 158 |
|
161 | 159 | See scatter for parameter docstring |
162 | 160 | """ |
163 | | - assert isinstance(nthreads, dict) |
164 | 161 | assert isinstance(data, dict) |
165 | 162 |
|
166 | | - workers = list(concat([w] * nc for w, nc in nthreads.items())) |
| 163 | + workers = sorted(workers) |
167 | 164 | names, data = list(zip(*data.items())) |
168 | 165 |
|
169 | 166 | worker_iter = drop(_round_robin_counter[0] % len(workers), cycle(workers)) |
|
0 commit comments