diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index ec4a684274..774854ad19 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -170,6 +171,7 @@ template auto train_vq(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) -> raft::device_matrix { + using kmeans_in_type = typename DatasetT::value_type; const ix_t n_rows = dataset.extent(0); const ix_t vq_n_centers = params.vq_n_centers; const ix_t dim = dataset.extent(1); @@ -181,16 +183,21 @@ auto train_vq(const raft::resources& res, const vpq_params& params, const Datase auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); - using kmeans_in_type = typename DatasetT::value_type; - cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; - auto vq_centers_view = - raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); auto vq_trainset_view = raft::make_device_matrix_view( vq_trainset.data_handle(), n_rows_train, dim); - cuvs::cluster::kmeans::fit(res, kmeans_params, vq_trainset_view, vq_centers_view); + if (vq_n_centers == 1) { + auto vq_centers_view = + raft::make_device_vector_view(vq_centers.data_handle(), dim); + raft::stats::mean(res, vq_trainset_view, vq_centers_view); + } else { + auto vq_centers_view = + raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); + cuvs::cluster::kmeans::balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; + cuvs::cluster::kmeans::fit(res, kmeans_params, vq_trainset_view, vq_centers_view); + } return vq_centers; }