Skip to content

Commit 0e2be4a

Browse files
authored
Fix race condition for published futures with annotations (#8577)
1 parent 42c479f commit 0e2be4a

3 files changed

Lines changed: 50 additions & 3 deletions

File tree

distributed/client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2650,18 +2650,23 @@ def retry(self, futures, asynchronous=None):
26502650
@log_errors
26512651
async def _publish_dataset(self, *args, name=None, override=False, **kwargs):
26522652
coroutines = []
2653+
uid = uuid.uuid4().hex
2654+
self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid})
26532655

26542656
def add_coro(name, data):
26552657
keys = [f.key for f in futures_of(data)]
2656-
coroutines.append(
2657-
self.scheduler.publish_put(
2658+
2659+
async def _():
2660+
await self.scheduler.publish_wait_flush(uid=uid)
2661+
await self.scheduler.publish_put(
26582662
keys=keys,
26592663
name=name,
26602664
data=to_serialize(data),
26612665
override=override,
26622666
client=self.id,
26632667
)
2664-
)
2668+
2669+
coroutines.append(_())
26652670

26662671
if name:
26672672
if len(args) == 0:

distributed/publish.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
from collections import defaultdict
35
from collections.abc import MutableMapping
46

57
from dask.utils import stringify
@@ -25,9 +27,21 @@ def __init__(self, scheduler):
2527
"publish_put": self.put,
2628
"publish_get": self.get,
2729
"publish_delete": self.delete,
30+
"publish_wait_flush": self.flush_wait,
31+
}
32+
stream_handlers = {
33+
"publish_flush_batched_send": self.flush_receive,
2834
}
2935

3036
self.scheduler.handlers.update(handlers)
37+
self.scheduler.stream_handlers.update(stream_handlers)
38+
self._flush_received = defaultdict(asyncio.Event)
39+
40+
def flush_receive(self, uid, **kwargs):
41+
self._flush_received[uid].set()
42+
43+
async def flush_wait(self, uid):
44+
await self._flush_received[uid].wait()
3145

3246
@log_errors
3347
def put(self, keys=None, data=None, name=None, override=False, client=None):

distributed/tests/test_publish.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from distributed.metrics import time
1212
from distributed.protocol import Serialized
1313
from distributed.utils_test import gen_cluster, inc
14+
from distributed.worker import get_worker
1415

1516

1617
@gen_cluster()
@@ -301,3 +302,30 @@ async def test_deserialize_client(c, s, a, b):
301302
from distributed.client import _current_client
302303

303304
assert _current_client.get() is c
305+
306+
307+
@gen_cluster(client=True, worker_kwargs={"resources": {"A": 1}})
308+
async def test_publish_submit_ordering(c, s, a, b):
309+
RESOURCES = {"A": 1}
310+
311+
def _retrieve_annotations():
312+
worker = get_worker()
313+
task = worker.state.tasks.get(worker.get_current_task())
314+
return task.annotations
315+
316+
# If publish does not take the same comm channel as the submit, it can
317+
# happen that the publish message reaches the scheduler before the submit
318+
# such that the state of the published future is not the one that has been
319+
# requested from the submit. Particularly, this lets us drop annotations
320+
# The current implementation does, in fact, not use the same channel due to a
321+
# serialization issue (including Futures in BatchedSend appends them to the
322+
# "recent messages" log which screws with the refcounting) but ensure that
323+
# all queued-up messages are flushed and received by the scheduler before
324+
# publishing
325+
future = c.submit(_retrieve_annotations, resources=RESOURCES, pure=False)
326+
327+
await c.publish_dataset(future, name="foo")
328+
assert await c.list_datasets() == ("foo",)
329+
330+
result = await future.result()
331+
assert result == {"resources": RESOURCES}

0 commit comments

Comments
 (0)