From afc63a7cfe7aff607e619af87cf1f069765354b3 Mon Sep 17 00:00:00 2001
From: Artem Chirkin <9253178+achirkin@users.noreply.github.com>
Date: Thu, 11 Jun 2026 08:06:21 -0700
Subject: [PATCH] Make concurrent streams wait on the dataset descriptor initialization stream

---
 cpp/src/neighbors/detail/cagra/compute_distance.hpp | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp
index 75b56860bb..2d4ccc00e3 100644
--- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp
+++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp
@@ -217,10 +217,12 @@ struct dataset_descriptor_host {
     std::mutex mutex;
     std::atomic ready;  // Not sure if std::holds_alternative is thread-safe
     std::variant value;
+    cudaEvent_t ready_event;

     template
     state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)}
     {
+      RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming));
     }

     ~state() noexcept
@@ -229,6 +231,7 @@ struct dataset_descriptor_host {
         auto& [ptr, stream] = std::get(value);
         RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream));
       }
+      RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event));
     }

     void eval(rmm::cuda_stream_view stream)
@@ -239,6 +242,7 @@ struct dataset_descriptor_host {
         dev_descriptor_t* ptr = nullptr;
         RAFT_CUDA_TRY(cudaMallocAsync(&ptr, size, stream));
         fun(ptr, stream);
+        RAFT_CUDA_TRY(cudaEventRecord(ready_event, stream));
         value = std::make_tuple(ptr, stream);
         ready.store(true, std::memory_order_release);
       }
@@ -247,7 +251,10 @@ struct dataset_descriptor_host {
     auto get(rmm::cuda_stream_view stream) -> dev_descriptor_t*
     {
       if (!ready.load(std::memory_order_acquire)) { eval(stream); }
-      return std::get<0>(std::get(value));
+      // value is immutable at this point.
+      auto& [ptr, ready_stream] = std::get(value);
+      if (ready_stream != stream) { RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, ready_event, 0)); }
+      return ptr;
     }
   };
