@@ -1232,12 +1232,90 @@ async def test_head(c, s, a, b):
12321232
12331233
def test_split_by_worker():
    """Shards a dataframe by the ``_partition`` column and checks that every
    row lands on exactly one of the assigned workers."""
    pytest.importorskip("pyarrow")

    df = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "_partition": [0, 1, 2, 0, 1],
        }
    )
    meta = df[["x"]].head(0)
    workers = ["alice", "bob"]
    npartitions = 3
    # Assign each output partition to a worker using range sharding.
    worker_for_mapping = {
        part: _get_worker_for_range_sharding(npartitions, part, workers)
        for part in range(npartitions)
    }
    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")

    out = split_by_worker(df, "_partition", meta, worker_for)
    assert set(out) == {"alice", "bob"}
    # Shards keep the input's column layout.
    assert list(out["alice"].to_pandas().columns) == list(df.columns)
    # No rows are lost or duplicated across the shards.
    assert sum(map(len, out.values())) == len(df)
1258+
def test_split_by_worker_empty():
    """When no worker is assigned to any partition present in the data,
    split_by_worker produces no shards at all."""
    pytest.importorskip("pyarrow")

    df = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "_partition": [0, 1, 2, 0, 1],
        }
    )
    meta = df[["x"]].head(0)
    # Partition 5 never occurs in ``df``, so "chuck" gets nothing.
    worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")

    out = split_by_worker(df, "_partition", meta, worker_for)
    assert out == {}
1273+
def test_split_by_worker_many_workers():
    """With more workers than occupied partitions, only the workers that own
    a partition actually present in the data receive a shard."""
    pytest.importorskip("pyarrow")

    df = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "_partition": [5, 7, 5, 0, 1],
        }
    )
    meta = df[["x"]].head(0)
    workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
    npartitions = 10
    worker_for_mapping = {
        part: _get_worker_for_range_sharding(npartitions, part, workers)
        for part in range(npartitions)
    }
    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")

    out = split_by_worker(df, "_partition", meta, worker_for)
    # Every partition id that occurs in ``df`` must map to a populated shard.
    for part in (5, 0, 7, 1):
        assert _get_worker_for_range_sharding(npartitions, part, workers) in out

    # Row count is conserved across the shards.
    assert sum(map(len, out.values())) == len(df)
1300+
@pytest.mark.parametrize("drop_column", [True, False])
def test_split_by_partition(drop_column):
    """Splitting an arrow table by ``_partition`` yields one table per
    distinct partition id, optionally dropping the partition column."""
    pa = pytest.importorskip("pyarrow")

    df = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "_partition": [3, 1, 2, 3, 1],
        }
    )
    t = pa.Table.from_pandas(df)

    out = split_by_partition(t, "_partition", drop_column)
    assert set(out) == {1, 2, 3}
    if drop_column:
        # Mirror the expected column layout on the pandas side.
        df = df.drop(columns="_partition")
    assert out[1].column_names == list(df.columns)
    # Row count is conserved across the split tables.
    assert sum(map(len, out.values())) == len(df)
12411319
12421320
12431321@gen_cluster (client = True , nthreads = [("" , 1 )] * 2 )
0 commit comments