From 9d685b5de3e04f0270649438cc8682f3640d976c Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 22 Jun 2026 14:59:58 +0000 Subject: [PATCH 1/3] Switch VQ to mean() when n_centers = 1 Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vpq_dataset.cuh | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index ec4a684274..ac2476af98 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -170,6 +170,7 @@ template auto train_vq(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) -> raft::device_matrix { + using kmeans_in_type = typename DatasetT::value_type; const ix_t n_rows = dataset.extent(0); const ix_t vq_n_centers = params.vq_n_centers; const ix_t dim = dataset.extent(1); @@ -181,16 +182,19 @@ auto train_vq(const raft::resources& res, const vpq_params& params, const Datase auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); - using kmeans_in_type = typename DatasetT::value_type; - cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; auto vq_centers_view = raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); auto vq_trainset_view = raft::make_device_matrix_view( vq_trainset.data_handle(), n_rows_train, dim); - cuvs::cluster::kmeans::fit(res, kmeans_params, vq_trainset_view, vq_centers_view); + if (vq_n_centers == 1) { + raft::stats::mean(res, vq_trainset_view, vq_centers_view); + } else { + cuvs::cluster::kmeans::balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; + cuvs::cluster::kmeans::fit(res, kmeans_params, vq_trainset_view, vq_centers_view); + } return vq_centers; } From e09106780ab9d1a5bca8efde5af59e56f742b2bb Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 23 Jun 2026 12:56:09 +0000 Subject: [PATCH 2/3] Add include Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vpq_dataset.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index ac2476af98..6054061256 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include From 9b5bd2adb59a912cc1dfd5748a2a69097dfd6224 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 23 Jun 2026 20:33:10 +0000 Subject: [PATCH 3/3] Change output mdspan Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vpq_dataset.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 6054061256..774854ad19 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -183,14 +183,16 @@ auto train_vq(const raft::resources& res, const vpq_params& params, const Datase auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); - auto vq_centers_view = - raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); auto vq_trainset_view = raft::make_device_matrix_view( vq_trainset.data_handle(), n_rows_train, dim); if (vq_n_centers == 1) { + auto vq_centers_view = + raft::make_device_vector_view(vq_centers.data_handle(), dim); raft::stats::mean(res, vq_trainset_view, vq_centers_view); } else { + auto vq_centers_view = + raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;