Skip to content

Commit 4986fa4

Browse files
Move tests (#8631)
1 parent 5a588ae commit 4986fa4

File tree

2 files changed

+84
-98
lines changed

2 files changed

+84
-98
lines changed

distributed/shuffle/tests/test_shuffle.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,12 +1232,90 @@ async def test_head(c, s, a, b):
12321232

12331233

12341234
def test_split_by_worker():
1235-
workers = ["a", "b", "c"]
1236-
npartitions = 5
1237-
df = pd.DataFrame({"x": range(100), "y": range(100)})
1238-
df["_partitions"] = df.x % npartitions
1239-
worker_for = {i: random.choice(workers) for i in range(npartitions)}
1240-
s = pd.Series(worker_for, name="_worker").astype("category")
1235+
pytest.importorskip("pyarrow")
1236+
1237+
df = pd.DataFrame(
1238+
{
1239+
"x": [1, 2, 3, 4, 5],
1240+
"_partition": [0, 1, 2, 0, 1],
1241+
}
1242+
)
1243+
meta = df[["x"]].head(0)
1244+
workers = ["alice", "bob"]
1245+
worker_for_mapping = {}
1246+
npartitions = 3
1247+
for part in range(npartitions):
1248+
worker_for_mapping[part] = _get_worker_for_range_sharding(
1249+
npartitions, part, workers
1250+
)
1251+
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
1252+
out = split_by_worker(df, "_partition", meta, worker_for)
1253+
assert set(out) == {"alice", "bob"}
1254+
assert list(out["alice"].to_pandas().columns) == list(df.columns)
1255+
1256+
assert sum(map(len, out.values())) == len(df)
1257+
1258+
1259+
def test_split_by_worker_empty():
1260+
pytest.importorskip("pyarrow")
1261+
1262+
df = pd.DataFrame(
1263+
{
1264+
"x": [1, 2, 3, 4, 5],
1265+
"_partition": [0, 1, 2, 0, 1],
1266+
}
1267+
)
1268+
meta = df[["x"]].head(0)
1269+
worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
1270+
out = split_by_worker(df, "_partition", meta, worker_for)
1271+
assert out == {}
1272+
1273+
1274+
def test_split_by_worker_many_workers():
1275+
pytest.importorskip("pyarrow")
1276+
1277+
df = pd.DataFrame(
1278+
{
1279+
"x": [1, 2, 3, 4, 5],
1280+
"_partition": [5, 7, 5, 0, 1],
1281+
}
1282+
)
1283+
meta = df[["x"]].head(0)
1284+
workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
1285+
npartitions = 10
1286+
worker_for_mapping = {}
1287+
for part in range(npartitions):
1288+
worker_for_mapping[part] = _get_worker_for_range_sharding(
1289+
npartitions, part, workers
1290+
)
1291+
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
1292+
out = split_by_worker(df, "_partition", meta, worker_for)
1293+
assert _get_worker_for_range_sharding(npartitions, 5, workers) in out
1294+
assert _get_worker_for_range_sharding(npartitions, 0, workers) in out
1295+
assert _get_worker_for_range_sharding(npartitions, 7, workers) in out
1296+
assert _get_worker_for_range_sharding(npartitions, 1, workers) in out
1297+
1298+
assert sum(map(len, out.values())) == len(df)
1299+
1300+
1301+
@pytest.mark.parametrize("drop_column", [True, False])
1302+
def test_split_by_partition(drop_column):
1303+
pa = pytest.importorskip("pyarrow")
1304+
1305+
df = pd.DataFrame(
1306+
{
1307+
"x": [1, 2, 3, 4, 5],
1308+
"_partition": [3, 1, 2, 3, 1],
1309+
}
1310+
)
1311+
t = pa.Table.from_pandas(df)
1312+
1313+
out = split_by_partition(t, "_partition", drop_column)
1314+
assert set(out) == {1, 2, 3}
1315+
if drop_column:
1316+
df = df.drop(columns="_partition")
1317+
assert out[1].column_names == list(df.columns)
1318+
assert sum(map(len, out.values())) == len(df)
12411319

12421320

12431321
@gen_cluster(client=True, nthreads=[("", 1)] * 2)

distributed/shuffle/tests/test_shuffle_plugins.py

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@
55
import pytest
66

77
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
8-
from distributed.shuffle._shuffle import (
9-
_get_worker_for_range_sharding,
10-
split_by_partition,
11-
split_by_worker,
12-
)
138
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
149
from distributed.utils_test import gen_cluster
1510

@@ -35,90 +30,3 @@ async def test_installation_on_scheduler(s, a):
3530
assert isinstance(ext, ShuffleSchedulerPlugin)
3631
assert s.handlers["shuffle_barrier"] == ext.barrier
3732
assert s.handlers["shuffle_get"] == ext.get
38-
39-
40-
def test_split_by_worker():
41-
pytest.importorskip("pyarrow")
42-
43-
df = pd.DataFrame(
44-
{
45-
"x": [1, 2, 3, 4, 5],
46-
"_partition": [0, 1, 2, 0, 1],
47-
}
48-
)
49-
meta = df[["x"]].head(0)
50-
workers = ["alice", "bob"]
51-
worker_for_mapping = {}
52-
npartitions = 3
53-
for part in range(npartitions):
54-
worker_for_mapping[part] = _get_worker_for_range_sharding(
55-
npartitions, part, workers
56-
)
57-
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
58-
out = split_by_worker(df, "_partition", meta, worker_for)
59-
assert set(out) == {"alice", "bob"}
60-
assert list(out["alice"].to_pandas().columns) == list(df.columns)
61-
62-
assert sum(map(len, out.values())) == len(df)
63-
64-
65-
def test_split_by_worker_empty():
66-
pytest.importorskip("pyarrow")
67-
68-
df = pd.DataFrame(
69-
{
70-
"x": [1, 2, 3, 4, 5],
71-
"_partition": [0, 1, 2, 0, 1],
72-
}
73-
)
74-
meta = df[["x"]].head(0)
75-
worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
76-
out = split_by_worker(df, "_partition", meta, worker_for)
77-
assert out == {}
78-
79-
80-
def test_split_by_worker_many_workers():
81-
pytest.importorskip("pyarrow")
82-
83-
df = pd.DataFrame(
84-
{
85-
"x": [1, 2, 3, 4, 5],
86-
"_partition": [5, 7, 5, 0, 1],
87-
}
88-
)
89-
meta = df[["x"]].head(0)
90-
workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
91-
npartitions = 10
92-
worker_for_mapping = {}
93-
for part in range(npartitions):
94-
worker_for_mapping[part] = _get_worker_for_range_sharding(
95-
npartitions, part, workers
96-
)
97-
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
98-
out = split_by_worker(df, "_partition", meta, worker_for)
99-
assert _get_worker_for_range_sharding(npartitions, 5, workers) in out
100-
assert _get_worker_for_range_sharding(npartitions, 0, workers) in out
101-
assert _get_worker_for_range_sharding(npartitions, 7, workers) in out
102-
assert _get_worker_for_range_sharding(npartitions, 1, workers) in out
103-
104-
assert sum(map(len, out.values())) == len(df)
105-
106-
107-
@pytest.mark.parametrize("drop_column", [True, False])
108-
def test_split_by_partition(drop_column):
109-
pa = pytest.importorskip("pyarrow")
110-
111-
df = pd.DataFrame(
112-
{
113-
"x": [1, 2, 3, 4, 5],
114-
"_partition": [3, 1, 2, 3, 1],
115-
}
116-
)
117-
t = pa.Table.from_pandas(df)
118-
119-
out = split_by_partition(t, "_partition", drop_column)
120-
assert set(out) == {1, 2, 3}
121-
if drop_column:
122-
df = df.drop(columns="_partition")
123-
assert out[1].column_names == list(df.columns)
124-
assert sum(map(len, out.values())) == len(df)

0 commit comments

Comments (0)