diff --git a/c/CMakeLists.txt b/c/CMakeLists.txt index be4bc7a051..0f3e5be5bf 100644 --- a/c/CMakeLists.txt +++ b/c/CMakeLists.txt @@ -85,6 +85,7 @@ endif() add_library( cuvs_c SHARED src/core/c_api.cpp + src/cluster/gmm.cpp src/cluster/kmeans.cpp src/neighbors/brute_force.cpp src/neighbors/ivf_flat.cpp diff --git a/c/include/cuvs/cluster/gmm.h b/c/include/cuvs/cluster/gmm.h new file mode 100644 index 0000000000..1609fa6652 --- /dev/null +++ b/c/include/cuvs/cluster/gmm.h @@ -0,0 +1,260 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @defgroup gmm_c_params Gaussian mixture hyperparameters + * @{ + */ + +/** + * @brief Covariance parameterization of the mixture components. + */ +typedef enum { + /** Each component has its own full covariance matrix. */ + CUVS_GMM_COVARIANCE_FULL = 0, + /** All components share a single full covariance matrix. */ + CUVS_GMM_COVARIANCE_TIED = 1, + /** Each component has its own diagonal covariance. */ + CUVS_GMM_COVARIANCE_DIAG = 2, + /** Each component has a single variance. */ + CUVS_GMM_COVARIANCE_SPHERICAL = 3 +} cuvsGMMCovarianceType; + +/** + * @brief Strategy used to initialize the responsibilities before EM. + */ +typedef enum { + /** Run k-means (itself seeded with k-means++) and use the hard labels. */ + CUVS_GMM_INIT_KMEANS = 0, + /** Use the k-means++ seeding labels directly. */ + CUVS_GMM_INIT_KMEANS_PLUS_PLUS = 1, + /** Random per-sample-normalized responsibilities. */ + CUVS_GMM_INIT_RANDOM = 2, + /** Pick n_components samples at random as one-hot responsibilities. */ + CUVS_GMM_INIT_RANDOM_FROM_DATA = 3 +} cuvsGMMInitMethod; + +/** + * @brief Hyper-parameters for the Gaussian mixture EM solver + */ +struct cuvsGMMParams { + /** + * The number of mixture components. Default: 1. 
+ */ + int n_components; + + /** + * Covariance parameterization of the mixture components. Default: FULL. + */ + cuvsGMMCovarianceType covariance_type; + + /** + * Convergence threshold on the change of the per-sample average + * log-likelihood (lower bound). Default: 1e-3. + */ + double tol; + + /** + * Non-negative regularization added to the diagonal of covariance. + * Default: 1e-6. + */ + double reg_covar; + + /** + * Maximum number of EM iterations for a single run. Default: 100. + */ + int max_iter; + + /** + * Number of initializations to perform; the best result is kept. Default: 1. + */ + int n_init; + + /** + * Strategy used to initialize the responsibilities before EM. + * Default: KMEANS. + */ + cuvsGMMInitMethod init; + + /** + * Seed to the random number generator. Default: 0. + */ + uint64_t seed; +}; + +typedef struct cuvsGMMParams* cuvsGMMParams_t; + +/** + * @brief Allocate GMM params, and populate with default values + * + * @param[out] params cuvsGMMParams_t to allocate + * @return cuvsError_t + */ +CUVS_EXPORT cuvsError_t cuvsGMMParamsCreate(cuvsGMMParams_t* params); + +/** + * @brief De-allocate GMM params + * + * @param[in] params + * @return cuvsError_t + */ +CUVS_EXPORT cuvsError_t cuvsGMMParamsDestroy(cuvsGMMParams_t params); + +/** + * @} + */ + +/** + * @defgroup gmm_c Gaussian mixture model APIs + * @{ + * + * The covariance-shaped tensors (``covariances``, ``precisions_chol``, + * ``precisions``) depend on ``covariance_type``. With ``K = n_components`` + * and ``d = n_features`` the expected shapes are (row-major): + * + * - ``CUVS_GMM_COVARIANCE_FULL``: (K, d, d) + * - ``CUVS_GMM_COVARIANCE_TIED``: (d, d) + * - ``CUVS_GMM_COVARIANCE_DIAG``: (K, d) + * - ``CUVS_GMM_COVARIANCE_SPHERICAL``: (K,) + */ + +/** + * @brief Fit a Gaussian mixture with the EM algorithm. + * + * Runs ``params->n_init`` random restarts (unless ``warm_start`` is true) and + * keeps the parameters with the largest lower bound. 
+ * + * All tensors must reside on device memory and be row-major. ``X``, + * ``weights``, ``means``, ``covariances``, ``precisions_chol`` and + * ``precisions`` must share one dtype (float32 or float64); ``labels`` is + * int32. + * + * @param[in] res opaque C handle + * @param[in] params Parameters for the GMM model. + * @param[in] X Training data. [dim = n_samples x n_features] + * @param[inout] weights Mixture weights. [len = n_components] + * @param[inout] means Component means. + * [dim = n_components x n_features] + * @param[inout] covariances Component covariances, flat. Length by + * covariance_type (K=n_components, d=n_features): + * FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] precisions_chol Precision Cholesky factors, same flat layout as + * covariances (FULL/TIED: upper-triangular factor + * U with precision = U @ Uᵀ; DIAG/SPHERICAL: + * reciprocal standard deviations). + * @param[out] precisions Precision matrices, same flat layout as + * covariances. + * @param[out] labels Hard component assignment per sample. + * [len = n_samples] + * @param[out] lower_bound Per-sample average log-likelihood of the best + * fit. + * @param[out] n_iter Number of EM iterations of the best fit. + * @param[out] converged Whether the best fit converged within tol. + * @param[in] warm_start Use the incoming weights/means/covariances as + * the single initialization. + */ +CUVS_EXPORT cuvsError_t cuvsGMMFit(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* covariances, + DLManagedTensor* precisions_chol, + DLManagedTensor* precisions, + DLManagedTensor* labels, + double* lower_bound, + int* n_iter, + bool* converged, + bool warm_start); + +/** + * @brief Hard component labels (argmax responsibility) for new data. + * + * @param[in] res opaque C handle + * @param[in] params Parameters used to fit the GMM model. + * @param[in] X Data to assign. 
[dim = n_samples x n_features] + * @param[in] weights Fitted mixture weights. [len = n_components] + * @param[in] means Fitted component means. + * [dim = n_components x n_features] + * @param[in] precisions_chol Fitted precision Cholesky factors, flat. Length + * by covariance_type (K=n_components, d=n_features): + * FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] labels Hard component assignment per sample (int32). + * [len = n_samples] + */ +CUVS_EXPORT cuvsError_t cuvsGMMPredict(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* labels); + +/** + * @brief Posterior responsibilities for new data. + * + * @param[in] res opaque C handle + * @param[in] params Parameters used to fit the GMM model. + * @param[in] X Data to evaluate. [dim = n_samples x n_features] + * @param[in] weights Fitted mixture weights. [len = n_components] + * @param[in] means Fitted component means. + * [dim = n_components x n_features] + * @param[in] precisions_chol Fitted precision Cholesky factors, flat. Length + * by covariance_type (K=n_components, d=n_features): + * FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] resp Posterior probability of each component for + * each sample. [dim = n_samples x n_components] + */ +CUVS_EXPORT cuvsError_t cuvsGMMPredictProba(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* resp); + +/** + * @brief Per-sample log-likelihood log p(x_i) for new data. + * + * @param[in] res opaque C handle + * @param[in] params Parameters used to fit the GMM model. + * @param[in] X Data to evaluate. [dim = n_samples x n_features] + * @param[in] weights Fitted mixture weights. [len = n_components] + * @param[in] means Fitted component means. 
+ * [dim = n_components x n_features] + * @param[in] precisions_chol Fitted precision Cholesky factors, flat. Length + * by covariance_type (K=n_components, d=n_features): + * FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] log_prob_norm Log-likelihood of each sample under the model. + * [len = n_samples] + */ +CUVS_EXPORT cuvsError_t cuvsGMMScoreSamples(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* log_prob_norm); + +/** + * @} + */ + +#ifdef __cplusplus +} +#endif diff --git a/c/include/cuvs/core/all.h b/c/include/cuvs/core/all.h index 545c7ec6f4..a99a14ed30 100644 --- a/c/include/cuvs/core/all.h +++ b/c/include/cuvs/core/all.h @@ -12,6 +12,7 @@ #include #include +#include #include #include diff --git a/c/src/cluster/gmm.cpp b/c/src/cluster/gmm.cpp new file mode 100644 index 0000000000..7ec50b65ea --- /dev/null +++ b/c/src/cluster/gmm.cpp @@ -0,0 +1,346 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include +#include +#include + +#include "../core/exceptions.hpp" +#include "../core/interop.hpp" + +namespace { + +cuvs::cluster::gmm::params convert_params(const cuvsGMMParams& params) +{ + auto gmm_params = cuvs::cluster::gmm::params(); + gmm_params.n_components = params.n_components; + gmm_params.cov_type = static_cast(params.covariance_type); + gmm_params.tol = params.tol; + gmm_params.reg_covar = params.reg_covar; + gmm_params.max_iter = params.max_iter; + gmm_params.n_init = params.n_init; + gmm_params.init = static_cast(params.init); + gmm_params.seed = params.seed; + return gmm_params; +} + +// Total number of elements of a (device) DLPack tensor, used for the +// covariance-shaped buffers whose rank depends on the covariance type. 
+int64_t tensor_numel(DLManagedTensor* tensor, const char* name) +{ + auto dl = tensor->dl_tensor; + if (!cuvs::core::is_dlpack_device_compatible(dl)) { + RAFT_FAIL("%s must be on device memory", name); + } + int64_t total = 1; + for (int i = 0; i < dl.ndim; ++i) + total *= dl.shape[i]; + return total; +} + +// The flat (covariance-shaped) buffers are reinterpret_cast to T, so their +// element type must match the dtype the call was dispatched on (X's dtype). +// from_dlpack enforces this for matrix/vector views; do the same here. +template +void check_flat_dtype(DLManagedTensor* tensor, const char* name) +{ + auto dt = tensor->dl_tensor.dtype; + RAFT_EXPECTS(dt.code == kDLFloat && dt.bits == sizeof(T) * 8, + "%s must be a %d-bit float buffer matching the dtype of X", + name, + static_cast(sizeof(T) * 8)); +} + +template +raft::device_vector_view flat_device_view(DLManagedTensor* tensor, const char* name) +{ + check_flat_dtype(tensor, name); + return raft::make_device_vector_view(reinterpret_cast(tensor->dl_tensor.data), + tensor_numel(tensor, name)); +} + +template +raft::device_vector_view flat_device_view_const(DLManagedTensor* tensor, + const char* name) +{ + check_flat_dtype(tensor, name); + return raft::make_device_vector_view( + reinterpret_cast(tensor->dl_tensor.data), tensor_numel(tensor, name)); +} + +template +void _fit(cuvsResources_t res, + const cuvsGMMParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* weights_tensor, + DLManagedTensor* means_tensor, + DLManagedTensor* covariances_tensor, + DLManagedTensor* precisions_chol_tensor, + DLManagedTensor* precisions_tensor, + DLManagedTensor* labels_tensor, + double* lower_bound, + int* n_iter, + bool* converged, + bool warm_start) +{ + auto res_ptr = reinterpret_cast(res); + if (!cuvs::core::is_dlpack_device_compatible(X_tensor->dl_tensor)) { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + using const_matrix_type = raft::device_matrix_view; + using matrix_type = 
raft::device_matrix_view; + using labels_type = raft::device_vector_view; + + T lower_bound_temp; + int n_iter_temp; + bool converged_temp; + + auto gmm_params = convert_params(params); + cuvs::cluster::gmm::fit(*res_ptr, + gmm_params, + cuvs::core::from_dlpack(X_tensor), + flat_device_view(weights_tensor, "weights"), + cuvs::core::from_dlpack(means_tensor), + flat_device_view(covariances_tensor, "covariances"), + flat_device_view(precisions_chol_tensor, "precisions_chol"), + flat_device_view(precisions_tensor, "precisions"), + cuvs::core::from_dlpack(labels_tensor), + raft::make_host_scalar_view(&lower_bound_temp), + raft::make_host_scalar_view(&n_iter_temp), + raft::make_host_scalar_view(&converged_temp), + warm_start); + + *lower_bound = lower_bound_temp; + *n_iter = n_iter_temp; + *converged = converged_temp; +} + +template +void _predict(cuvsResources_t res, + const cuvsGMMParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* weights_tensor, + DLManagedTensor* means_tensor, + DLManagedTensor* precisions_chol_tensor, + DLManagedTensor* labels_tensor) +{ + auto res_ptr = reinterpret_cast(res); + if (!cuvs::core::is_dlpack_device_compatible(X_tensor->dl_tensor)) { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + using const_matrix_type = raft::device_matrix_view; + using labels_type = raft::device_vector_view; + + auto gmm_params = convert_params(params); + cuvs::cluster::gmm::predict(*res_ptr, + gmm_params, + cuvs::core::from_dlpack(X_tensor), + flat_device_view_const(weights_tensor, "weights"), + cuvs::core::from_dlpack(means_tensor), + flat_device_view_const(precisions_chol_tensor, "precisions_chol"), + cuvs::core::from_dlpack(labels_tensor)); +} + +template +void _predict_proba(cuvsResources_t res, + const cuvsGMMParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* weights_tensor, + DLManagedTensor* means_tensor, + DLManagedTensor* precisions_chol_tensor, + DLManagedTensor* resp_tensor) +{ + auto res_ptr = 
reinterpret_cast(res); + if (!cuvs::core::is_dlpack_device_compatible(X_tensor->dl_tensor)) { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + using const_matrix_type = raft::device_matrix_view; + using matrix_type = raft::device_matrix_view; + + auto gmm_params = convert_params(params); + cuvs::cluster::gmm::predict_proba( + *res_ptr, + gmm_params, + cuvs::core::from_dlpack(X_tensor), + flat_device_view_const(weights_tensor, "weights"), + cuvs::core::from_dlpack(means_tensor), + flat_device_view_const(precisions_chol_tensor, "precisions_chol"), + cuvs::core::from_dlpack(resp_tensor)); +} + +template +void _score_samples(cuvsResources_t res, + const cuvsGMMParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* weights_tensor, + DLManagedTensor* means_tensor, + DLManagedTensor* precisions_chol_tensor, + DLManagedTensor* log_prob_norm_tensor) +{ + auto res_ptr = reinterpret_cast(res); + if (!cuvs::core::is_dlpack_device_compatible(X_tensor->dl_tensor)) { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + using const_matrix_type = raft::device_matrix_view; + + auto gmm_params = convert_params(params); + cuvs::cluster::gmm::score_samples( + *res_ptr, + gmm_params, + cuvs::core::from_dlpack(X_tensor), + flat_device_view_const(weights_tensor, "weights"), + cuvs::core::from_dlpack(means_tensor), + flat_device_view_const(precisions_chol_tensor, "precisions_chol"), + flat_device_view(log_prob_norm_tensor, "log_prob_norm")); +} + +} // namespace + +extern "C" cuvsError_t cuvsGMMParamsCreate(cuvsGMMParams_t* params) +{ + return cuvs::core::translate_exceptions([=] { + cuvs::cluster::gmm::params cpp_params; + *params = + new cuvsGMMParams{.n_components = cpp_params.n_components, + .covariance_type = static_cast(cpp_params.cov_type), + .tol = cpp_params.tol, + .reg_covar = cpp_params.reg_covar, + .max_iter = cpp_params.max_iter, + .n_init = cpp_params.n_init, + .init = static_cast(cpp_params.init), + .seed = cpp_params.seed}; + 
}); +} + +extern "C" cuvsError_t cuvsGMMParamsDestroy(cuvsGMMParams_t params) +{ + return cuvs::core::translate_exceptions([=] { delete params; }); +} + +extern "C" cuvsError_t cuvsGMMFit(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* covariances, + DLManagedTensor* precisions_chol, + DLManagedTensor* precisions, + DLManagedTensor* labels, + double* lower_bound, + int* n_iter, + bool* converged, + bool warm_start) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _fit(res, + *params, + X, + weights, + means, + covariances, + precisions_chol, + precisions, + labels, + lower_bound, + n_iter, + converged, + warm_start); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _fit(res, + *params, + X, + weights, + means, + covariances, + precisions_chol, + precisions, + labels, + lower_bound, + n_iter, + converged, + warm_start); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsGMMPredict(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* labels) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _predict(res, *params, X, weights, means, precisions_chol, labels); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _predict(res, *params, X, weights, means, precisions_chol, labels); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsGMMPredictProba(cuvsResources_t res, + 
cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* resp) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _predict_proba(res, *params, X, weights, means, precisions_chol, resp); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _predict_proba(res, *params, X, weights, means, precisions_chol, resp); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsGMMScoreSamples(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* log_prob_norm) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _score_samples(res, *params, X, weights, means, precisions_chol, log_prob_norm); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _score_samples(res, *params, X, weights, means, precisions_chol, log_prob_norm); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} diff --git a/c/tests/CMakeLists.txt b/c/tests/CMakeLists.txt index f1cff7824e..d481f89879 100644 --- a/c/tests/CMakeLists.txt +++ b/c/tests/CMakeLists.txt @@ -79,6 +79,7 @@ ConfigureTest( NAME DISTANCE_C_TEST PATH distance/run_pairwise_distance_c.c distance/pairwise_distance_c.cu ) ConfigureTest(NAME KMEANS_C_TEST PATH cluster/kmeans_c.cu) +ConfigureTest(NAME GMM_C_TEST PATH cluster/gmm_c.cu) ConfigureTest(NAME BRUTEFORCE_C_TEST PATH neighbors/run_brute_force_c.c neighbors/brute_force_c.cu) ConfigureTest(NAME IVF_FLAT_C_TEST PATH neighbors/run_ivf_flat_c.c 
neighbors/ann_ivf_flat_c.cu) ConfigureTest(NAME IVF_PQ_C_TEST PATH neighbors/run_ivf_pq_c.c neighbors/ann_ivf_pq_c.cu) diff --git a/c/tests/cluster/gmm_c.cu b/c/tests/cluster/gmm_c.cu new file mode 100644 index 0000000000..44f21ac142 --- /dev/null +++ b/c/tests/cluster/gmm_c.cu @@ -0,0 +1,271 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "test_utils.cuh" + +#include +#include + +#include +#include +#include +#include + +#include "../../src/core/interop.hpp" +#include +#include + +#include +#include +#include + +namespace { + +constexpr int64_t kNSamples = 8; +constexpr int64_t kNFeatures = 2; +constexpr int kNComponents = 2; + +// Two tight, well-separated clusters of four points each. +float kDataset[kNSamples][kNFeatures] = { + {1.0f, 1.0f}, + {1.0f, 2.0f}, + {2.0f, 1.0f}, + {2.0f, 2.0f}, + {10.0f, 10.0f}, + {10.0f, 11.0f}, + {11.0f, 10.0f}, + {11.0f, 11.0f}, +}; + +void test_fit_predict() +{ + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + int64_t cn = (int64_t)kNComponents * kNFeatures * kNFeatures; // FULL + + rmm::device_uvector dataset_d(kNSamples * kNFeatures, stream); + rmm::device_uvector weights_d(kNComponents, stream); + rmm::device_uvector means_d(kNComponents * kNFeatures, stream); + rmm::device_uvector covs_d(cn, stream); + rmm::device_uvector pchol_d(cn, stream); + rmm::device_uvector precs_d(cn, stream); + rmm::device_uvector labels_d(kNSamples, stream); + rmm::device_uvector labels2_d(kNSamples, stream); + rmm::device_uvector resp_d(kNSamples * kNComponents, stream); + rmm::device_uvector logp_d(kNSamples, stream); + + raft::copy( + dataset_d.data(), reinterpret_cast(kDataset), kNSamples * kNFeatures, stream); + + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + // Bind the resource to the same stream the data was copied on; raft::copy above + // is async and cuvsGMMFit would otherwise race it on 
the resource's own stream. + ASSERT_EQ(cuvsStreamSet(res, stream), CUVS_SUCCESS); + + cuvsGMMParams_t params; + ASSERT_EQ(cuvsGMMParamsCreate(&params), CUVS_SUCCESS); + params->n_components = kNComponents; + params->covariance_type = CUVS_GMM_COVARIANCE_FULL; + params->max_iter = 100; + params->seed = 1234ULL; + + auto to_dl_mat = [&](float* p, int64_t r, int64_t c, DLManagedTensor* t) { + cuvs::core::to_dlpack(raft::make_device_matrix_view(p, r, c), t); + }; + auto to_dl_vec = [&](float* p, int64_t n, DLManagedTensor* t) { + cuvs::core::to_dlpack(raft::make_device_vector_view(p, n), t); + }; + + DLManagedTensor X_t{}, w_t{}, m_t{}, cov_t{}, pc_t{}, pr_t{}, lab_t{}, lab2_t{}, resp_t{}, lp_t{}; + to_dl_mat(dataset_d.data(), kNSamples, kNFeatures, &X_t); + to_dl_vec(weights_d.data(), kNComponents, &w_t); + to_dl_mat(means_d.data(), kNComponents, kNFeatures, &m_t); + to_dl_vec(covs_d.data(), cn, &cov_t); + to_dl_vec(pchol_d.data(), cn, &pc_t); + to_dl_vec(precs_d.data(), cn, &pr_t); + cuvs::core::to_dlpack(raft::make_device_vector_view(labels_d.data(), kNSamples), + &lab_t); + cuvs::core::to_dlpack( + raft::make_device_vector_view(labels2_d.data(), kNSamples), &lab2_t); + to_dl_mat(resp_d.data(), kNSamples, kNComponents, &resp_t); + to_dl_vec(logp_d.data(), kNSamples, &lp_t); + + double lower_bound = 0.0; + int n_iter = -1; + bool converged = false; + + ASSERT_EQ(cuvsGMMFit(res, + params, + &X_t, + &w_t, + &m_t, + &cov_t, + &pc_t, + &pr_t, + &lab_t, + &lower_bound, + &n_iter, + &converged, + /* warm_start */ false), + CUVS_SUCCESS); + EXPECT_GT(n_iter, 0); + + ASSERT_EQ(cuvsGMMPredict(res, params, &X_t, &w_t, &m_t, &pc_t, &lab2_t), CUVS_SUCCESS); + ASSERT_EQ(cuvsGMMPredictProba(res, params, &X_t, &w_t, &m_t, &pc_t, &resp_t), CUVS_SUCCESS); + ASSERT_EQ(cuvsGMMScoreSamples(res, params, &X_t, &w_t, &m_t, &pc_t, &lp_t), CUVS_SUCCESS); + + std::vector h_labels(kNSamples), h_labels2(kNSamples); + raft::copy(h_labels.data(), labels_d.data(), kNSamples, stream); +
raft::copy(h_labels2.data(), labels2_d.data(), kNSamples, stream); + std::vector h_resp(kNSamples * kNComponents); + raft::copy(h_resp.data(), resp_d.data(), kNSamples * kNComponents, stream); + std::vector h_logp(kNSamples); + raft::copy(h_logp.data(), logp_d.data(), kNSamples, stream); + raft::resource::sync_stream(handle, stream); + + // score_samples returns a finite per-sample log-likelihood. + for (int i = 0; i < kNSamples; ++i) + EXPECT_TRUE(std::isfinite(h_logp[i])); + + // fit and predict agree, and both recover the two-cluster partition (the two + // halves get the same label within each half, different across halves). + for (int i = 0; i < kNSamples; ++i) + EXPECT_EQ(h_labels[i], h_labels2[i]); + EXPECT_EQ(h_labels[0], h_labels[1]); + EXPECT_EQ(h_labels[0], h_labels[2]); + EXPECT_EQ(h_labels[0], h_labels[3]); + EXPECT_EQ(h_labels[4], h_labels[5]); + EXPECT_NE(h_labels[0], h_labels[4]); + + // responsibilities normalized per row. + for (int i = 0; i < kNSamples; ++i) { + float s = h_resp[i * kNComponents] + h_resp[i * kNComponents + 1]; + EXPECT_NEAR(s, 1.0f, 1e-3f); + } + + lp_t.deleter(&lp_t); + resp_t.deleter(&resp_t); + lab2_t.deleter(&lab2_t); + lab_t.deleter(&lab_t); + pr_t.deleter(&pr_t); + pc_t.deleter(&pc_t); + cov_t.deleter(&cov_t); + m_t.deleter(&m_t); + w_t.deleter(&w_t); + X_t.deleter(&X_t); + + ASSERT_EQ(cuvsGMMParamsDestroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +// Exercises the float64 dispatch and the DIAG flat-buffer sizing (K*d) through +// the C boundary, plus cuvsGMMScoreSamples on the alternate dtype. 
+void test_fit_score_double_diag() +{ + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + int64_t cn = (int64_t)kNComponents * kNFeatures; // DIAG + + std::vector h_X(kNSamples * kNFeatures); + for (int i = 0; i < kNSamples; ++i) + for (int j = 0; j < kNFeatures; ++j) + h_X[i * kNFeatures + j] = static_cast(kDataset[i][j]); + + rmm::device_uvector X_d(kNSamples * kNFeatures, stream); + rmm::device_uvector weights_d(kNComponents, stream); + rmm::device_uvector means_d(kNComponents * kNFeatures, stream); + rmm::device_uvector covs_d(cn, stream); + rmm::device_uvector pchol_d(cn, stream); + rmm::device_uvector precs_d(cn, stream); + rmm::device_uvector labels_d(kNSamples, stream); + rmm::device_uvector logp_d(kNSamples, stream); + raft::copy(X_d.data(), h_X.data(), kNSamples * kNFeatures, stream); + + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + // Bind the resource to the same stream the data was copied on; raft::copy above + // is async and cuvsGMMFit would otherwise race it on the resource's own stream. 
+ ASSERT_EQ(cuvsStreamSet(res, stream), CUVS_SUCCESS); + cuvsGMMParams_t params; + ASSERT_EQ(cuvsGMMParamsCreate(&params), CUVS_SUCCESS); + params->n_components = kNComponents; + params->covariance_type = CUVS_GMM_COVARIANCE_DIAG; + params->max_iter = 100; + params->seed = 1234ULL; + + auto dmat = [&](double* p, int64_t r, int64_t c, DLManagedTensor* t) { + cuvs::core::to_dlpack(raft::make_device_matrix_view(p, r, c), t); + }; + auto dvec = [&](double* p, int64_t n, DLManagedTensor* t) { + cuvs::core::to_dlpack(raft::make_device_vector_view(p, n), t); + }; + + DLManagedTensor X_t{}, w_t{}, m_t{}, cov_t{}, pc_t{}, pr_t{}, lab_t{}, lp_t{}; + dmat(X_d.data(), kNSamples, kNFeatures, &X_t); + dvec(weights_d.data(), kNComponents, &w_t); + dmat(means_d.data(), kNComponents, kNFeatures, &m_t); + dvec(covs_d.data(), cn, &cov_t); + dvec(pchol_d.data(), cn, &pc_t); + dvec(precs_d.data(), cn, &pr_t); + cuvs::core::to_dlpack(raft::make_device_vector_view(labels_d.data(), kNSamples), + &lab_t); + dvec(logp_d.data(), kNSamples, &lp_t); + + double lower_bound = 0.0; + int n_iter = -1; + bool converged = false; + ASSERT_EQ(cuvsGMMFit(res, + params, + &X_t, + &w_t, + &m_t, + &cov_t, + &pc_t, + &pr_t, + &lab_t, + &lower_bound, + &n_iter, + &converged, + /* warm_start */ false), + CUVS_SUCCESS); + EXPECT_GT(n_iter, 0); + ASSERT_EQ(cuvsGMMScoreSamples(res, params, &X_t, &w_t, &m_t, &pc_t, &lp_t), CUVS_SUCCESS); + + std::vector h_logp(kNSamples); + raft::copy(h_logp.data(), logp_d.data(), kNSamples, stream); + raft::resource::sync_stream(handle, stream); + for (int i = 0; i < kNSamples; ++i) + EXPECT_TRUE(std::isfinite(h_logp[i])); + + lp_t.deleter(&lp_t); + lab_t.deleter(&lab_t); + pr_t.deleter(&pr_t); + pc_t.deleter(&pc_t); + cov_t.deleter(&cov_t); + m_t.deleter(&m_t); + w_t.deleter(&w_t); + X_t.deleter(&X_t); + ASSERT_EQ(cuvsGMMParamsDestroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +} // namespace + +TEST(GMMC, FitPredict) { test_fit_predict(); } 
+ +TEST(GMMC, FitScoreDoubleDiag) { test_fit_score_double_diag(); } + +TEST(GMMC, ParamsCreateDestroy) +{ + cuvsGMMParams_t params = nullptr; + ASSERT_EQ(cuvsGMMParamsCreate(&params), CUVS_SUCCESS); + ASSERT_NE(params, nullptr); + EXPECT_GT(params->n_components, 0); + EXPECT_GT(params->max_iter, 0); + ASSERT_EQ(cuvsGMMParamsDestroy(params), CUVS_SUCCESS); +} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 89cbadfcfc..1e36e6f14f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1303,6 +1303,8 @@ if(NOT BUILD_CPU_ONLY) cuvs_objs OBJECT src/cluster/detail/minClusterDistanceCompute.cu src/cluster/agglomerative.cu + src/cluster/gmm_double.cu + src/cluster/gmm_float.cu src/cluster/kmeans_cluster_cost.cu src/cluster/kmeans_fit_mg_float.cu src/cluster/kmeans_fit_mg_double.cu diff --git a/cpp/include/cuvs/cluster/gmm.hpp b/cpp/include/cuvs/cluster/gmm.hpp new file mode 100644 index 0000000000..63c4f2b6fd --- /dev/null +++ b/cpp/include/cuvs/cluster/gmm.hpp @@ -0,0 +1,334 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include +#include + +#include + +#include + +namespace CUVS_EXPORT cuvs { +namespace cluster { +namespace gmm { + +/** + * @defgroup gmm_params Gaussian mixture hyperparameters + * @{ + */ + +/** Covariance parameterization of the mixture components. */ +enum class covariance_type { FULL = 0, TIED = 1, DIAG = 2, SPHERICAL = 3 }; + +/** Strategy used to initialize the responsibilities before EM. */ +enum class init_method { + /** Run k-means (itself seeded with k-means++) and use the hard labels. */ + KMeans = 0, + /** Use the k-means++ seeding labels directly. */ + KMeansPlusPlus = 1, + /** Random per-sample-normalized responsibilities. */ + Random = 2, + /** Pick n_components samples at random as one-hot responsibilities. */ + RandomFromData = 3, +}; + +/** Hyper-parameters for the Gaussian mixture EM solver. 
*/ +struct params { + /** The number of mixture components. Default: 1. */ + int n_components = 1; + /** Covariance parameterization of the mixture components. Default: FULL. */ + covariance_type cov_type = covariance_type::FULL; + /** Convergence threshold on the change of the per-sample average + * log-likelihood (lower bound). Default: 1e-3. */ + double tol = 1e-3; + /** Non-negative regularization added to the diagonal of covariance. + * Default: 1e-6. */ + double reg_covar = 1e-6; + /** Maximum number of EM iterations for a single run. Default: 100. */ + int max_iter = 100; + /** Number of initializations to perform; the best result is kept. + * Default: 1. */ + int n_init = 1; + /** Strategy used to initialize the responsibilities before EM. + * Default: KMeans. */ + init_method init = init_method::KMeans; + /** Seed to the random number generator. Default: 0. */ + uint64_t seed = 0; +}; + +/** + * @} + */ + +/** + * @defgroup gmm Gaussian mixture model APIs + * @{ + * + * The covariance-shaped buffers (``covariances``, ``precisions_chol``, + * ``precisions``) are passed as flat device vectors because their logical + * shape depends on ``params::cov_type``. With ``K = n_components`` and + * ``d = n_features`` the expected lengths are (row-major): + * + * - ``FULL``: K * d * d (logically (K, d, d)) + * - ``TIED``: d * d (logically (d, d)) + * - ``DIAG``: K * d (logically (K, d)) + * - ``SPHERICAL``: K (logically (K,)) + * + * For ``FULL``/``TIED``, ``precisions_chol`` holds the upper-triangular + * factor ``U`` of each precision matrix (precision ``= U @ Uᵀ``); for + * ``DIAG``/``SPHERICAL`` it holds reciprocal standard deviations. These + * conventions match scikit-learn's ``GaussianMixture``. + */ + +/** + * @brief Fit a Gaussian mixture with the EM algorithm. + * + * Runs ``params.n_init`` random restarts (unless @p warm_start is true) and + * keeps the parameters with the largest lower bound. 
Writes the fitted + * ``weights``, ``means``, ``covariances``, ``precisions_chol`` and + * ``precisions``, the per-sample hard ``labels`` (argmax of the final + * responsibilities), and the scalar ``lower_bound`` / ``n_iter`` / + * ``converged`` diagnostics. + * + * When @p warm_start is true the incoming ``weights`` / ``means`` / + * ``covariances`` are used as the single initialization and ``params.n_init`` + * is ignored. + * + * @code{.cpp} + * #include + * #include + * ... + * raft::resources handle; + * cuvs::cluster::gmm::params params; + * params.n_components = 3; + * + * int64_t K = params.n_components, d = X.extent(1); + * auto weights = raft::make_device_vector(handle, K); + * auto means = raft::make_device_matrix(handle, K, d); + * auto covs = raft::make_device_vector(handle, K * d * d); + * auto pchol = raft::make_device_vector(handle, K * d * d); + * auto precs = raft::make_device_vector(handle, K * d * d); + * auto labels = raft::make_device_vector(handle, X.extent(0)); + * float lower_bound; + * int n_iter; + * bool converged; + * + * gmm::fit(handle, params, X, weights.view(), means.view(), covs.view(), + * pchol.view(), precs.view(), labels.view(), + * raft::make_host_scalar_view(&lower_bound), + * raft::make_host_scalar_view(&n_iter), + * raft::make_host_scalar_view(&converged)); + * @endcode + * + * @param[in] handle The raft resources handle. + * @param[in] params Hyper-parameters of the EM solver. + * @param[in] X Training data, row-major. + * [dim = n_samples x n_features] + * @param[inout] weights Mixture weights. [len = n_components] + * @param[inout] means Component means, row-major. + * [dim = n_components x n_features] + * @param[inout] covariances Component covariances, flat. Length depends on + * cov_type (K=n_components, d=n_features): FULL + * K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] precisions_chol Precision Cholesky factors, same flat layout as + * covariances. 
FULL/TIED hold the upper-triangular + * factor U (precision = U @ Uᵀ); DIAG/SPHERICAL + * hold reciprocal standard deviations. + * @param[out] precisions Precision matrices, same flat layout as + * covariances. + * @param[out] labels Hard component assignment per sample. + * [len = n_samples] + * @param[out] lower_bound Per-sample average log-likelihood of the best + * fit. + * @param[out] n_iter Number of EM iterations of the best fit. + * @param[out] converged Whether the best fit converged within + * ``params.tol``. + * @param[in] warm_start Use the incoming weights/means/covariances as + * the single initialization. + */ +template +void fit(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view covariances, + raft::device_vector_view precisions_chol, + raft::device_vector_view precisions, + raft::device_vector_view labels, + raft::host_scalar_view lower_bound, + raft::host_scalar_view n_iter, + raft::host_scalar_view converged, + bool warm_start = false); + +extern template void fit(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view covariances, + raft::device_vector_view precisions_chol, + raft::device_vector_view precisions, + raft::device_vector_view labels, + raft::host_scalar_view lower_bound, + raft::host_scalar_view n_iter, + raft::host_scalar_view converged, + bool warm_start); + +extern template void fit(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view covariances, + raft::device_vector_view precisions_chol, + raft::device_vector_view precisions, + raft::device_vector_view labels, + raft::host_scalar_view lower_bound, + raft::host_scalar_view n_iter, + 
raft::host_scalar_view converged, + bool warm_start); + +/** + * @brief Hard component labels (argmax responsibility) for new data. + * + * @param[in] handle The raft resources handle. + * @param[in] params Fit hyper-parameters; only n_components and + * cov_type are consulted at inference time. + * @param[in] X Data to assign, row-major. + * [dim = n_samples x n_features] + * @param[in] weights Fitted mixture weights. [len = n_components] + * @param[in] means Fitted component means. + * [dim = n_components x n_features] + * @param[in] precisions_chol Fitted precision Cholesky factors, flat. Length + * by cov_type (K=n_components, d=n_features): FULL + * K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] labels Hard component assignment per sample. + * [len = n_samples] + */ +template +void predict(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view labels); + +extern template void predict(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view labels); + +extern template void predict( + raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view labels); + +/** + * @brief Posterior responsibilities for new data. + * + * @param[in] handle The raft resources handle. + * @param[in] params Fit hyper-parameters; only n_components and + * cov_type are consulted at inference time. + * @param[in] X Data to evaluate, row-major. + * [dim = n_samples x n_features] + * @param[in] weights Fitted mixture weights. 
[len = n_components] + * @param[in] means Fitted component means. + * [dim = n_components x n_features] + * @param[in] precisions_chol Fitted precision Cholesky factors, flat. Length + * by cov_type (K=n_components, d=n_features): FULL + * K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] resp Posterior probability of each component for + * each sample, row-major. + * [dim = n_samples x n_components] + */ +template +void predict_proba(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_matrix_view resp); + +extern template void predict_proba( + raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_matrix_view resp); + +extern template void predict_proba( + raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_matrix_view resp); + +/** + * @brief Per-sample log-likelihood log p(x_i) for new data. + * + * @param[in] handle The raft resources handle. + * @param[in] params Fit hyper-parameters; only n_components and + * cov_type are consulted at inference time. + * @param[in] X Data to evaluate, row-major. + * [dim = n_samples x n_features] + * @param[in] weights Fitted mixture weights. [len = n_components] + * @param[in] means Fitted component means. + * [dim = n_components x n_features] + * @param[in] precisions_chol Fitted precision Cholesky factors, flat. Length + * by cov_type (K=n_components, d=n_features): FULL + * K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. + * @param[out] log_prob_norm Log-likelihood of each sample under the model. 
+ * [len = n_samples] + */ +template +void score_samples(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view log_prob_norm); + +extern template void score_samples( + raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view log_prob_norm); + +extern template void score_samples( + raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view log_prob_norm); + +/** + * @} + */ + +} // namespace gmm +} // namespace cluster +} // namespace CUVS_EXPORT cuvs diff --git a/cpp/src/cluster/gmm.cuh b/cpp/src/cluster/gmm.cuh new file mode 100644 index 0000000000..bd4272e8b9 --- /dev/null +++ b/cpp/src/cluster/gmm.cuh @@ -0,0 +1,177 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "gmm_impl.cuh" + +#include + +#include + +#include + +namespace cuvs::cluster::gmm { + +namespace { + +// Validate the shared (X, weights, means, precisions_chol) arguments and +// write (n, d, K) as ints (the internal kernels index with 32-bit math). 
+template +void check_common_args(const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + int& n, + int& d, + int& K) +{ + int64_t n64 = X.extent(0); + int64_t d64 = X.extent(1); + int64_t K64 = params.n_components; + RAFT_EXPECTS(n64 > 0 && d64 > 0, "X must be non-empty"); + RAFT_EXPECTS(K64 > 0, "n_components must be positive"); + RAFT_EXPECTS(n64 <= std::numeric_limits::max() && d64 <= std::numeric_limits::max(), + "gmm currently supports up to 2^31-1 samples / features"); + RAFT_EXPECTS(weights.extent(0) == K64, "weights must have n_components elements"); + RAFT_EXPECTS(means.extent(0) == K64 && means.extent(1) == d64, + "means must be of shape (n_components, n_features)"); + auto expected = detail::cov_elems(params.cov_type, (int)d64, (int)K64); + RAFT_EXPECTS((size_t)precisions_chol.extent(0) == expected, + "precisions_chol has the wrong number of elements for the covariance type"); + n = (int)n64; + d = (int)d64; + K = (int)K64; +} + +} // namespace + +template +void fit(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view covariances, + raft::device_vector_view precisions_chol, + raft::device_vector_view precisions, + raft::device_vector_view labels, + raft::host_scalar_view lower_bound, + raft::host_scalar_view n_iter, + raft::host_scalar_view converged, + bool warm_start) +{ + int n, d, K; + check_common_args( + params, + X, + raft::make_device_vector_view(weights.data_handle(), weights.extent(0)), + raft::make_device_matrix_view( + means.data_handle(), means.extent(0), means.extent(1)), + raft::make_device_vector_view(precisions_chol.data_handle(), + precisions_chol.extent(0)), + n, + d, + K); + auto expected = detail::cov_elems(params.cov_type, d, K); + RAFT_EXPECTS((size_t)covariances.extent(0) == expected, + 
"covariances has the wrong number of elements for the covariance type"); + RAFT_EXPECTS((size_t)precisions.extent(0) == expected, + "precisions has the wrong number of elements for the covariance type"); + RAFT_EXPECTS(labels.extent(0) == X.extent(0), "labels must have n_samples elements"); + RAFT_EXPECTS(params.n_init > 0, "n_init must be positive"); + RAFT_EXPECTS(params.max_iter >= 0, "max_iter must be non-negative"); + RAFT_EXPECTS(K <= n, "n_components must be <= n_samples"); + + detail::fit_impl(handle, + params, + X.data_handle(), + n, + d, + weights.data_handle(), + means.data_handle(), + covariances.data_handle(), + precisions_chol.data_handle(), + precisions.data_handle(), + labels.data_handle(), + *lower_bound.data_handle(), + *n_iter.data_handle(), + *converged.data_handle(), + warm_start); +} + +template +void predict(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view labels) +{ + int n, d, K; + check_common_args(params, X, weights, means, precisions_chol, n, d, K); + RAFT_EXPECTS(labels.extent(0) == X.extent(0), "labels must have n_samples elements"); + detail::predict_impl(handle, + params, + X.data_handle(), + n, + d, + weights.data_handle(), + means.data_handle(), + precisions_chol.data_handle(), + labels.data_handle()); +} + +template +void predict_proba(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_matrix_view resp) +{ + int n, d, K; + check_common_args(params, X, weights, means, precisions_chol, n, d, K); + RAFT_EXPECTS(resp.extent(0) == X.extent(0) && resp.extent(1) == (int64_t)K, + "resp must be of shape (n_samples, n_components)"); + detail::predict_proba_impl(handle, + params, + X.data_handle(), + n, + d, 
+ weights.data_handle(), + means.data_handle(), + precisions_chol.data_handle(), + resp.data_handle()); +} + +template +void score_samples(raft::resources const& handle, + const params& params, + raft::device_matrix_view X, + raft::device_vector_view weights, + raft::device_matrix_view means, + raft::device_vector_view precisions_chol, + raft::device_vector_view log_prob_norm) +{ + int n, d, K; + check_common_args(params, X, weights, means, precisions_chol, n, d, K); + RAFT_EXPECTS(log_prob_norm.extent(0) == X.extent(0), + "log_prob_norm must have n_samples elements"); + detail::score_samples_impl(handle, + params, + X.data_handle(), + n, + d, + weights.data_handle(), + means.data_handle(), + precisions_chol.data_handle(), + log_prob_norm.data_handle()); +} + +} // namespace cuvs::cluster::gmm diff --git a/cpp/src/cluster/gmm_double.cu b/cpp/src/cluster/gmm_double.cu new file mode 100644 index 0000000000..95f81b936f --- /dev/null +++ b/cpp/src/cluster/gmm_double.cu @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "gmm.cuh" + +#include +#include + +namespace cuvs::cluster::gmm { + +template CUVS_EXPORT void fit(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_vector_view, + raft::device_vector_view, + raft::device_vector_view, + raft::host_scalar_view, + raft::host_scalar_view, + raft::host_scalar_view, + bool); + +template CUVS_EXPORT void predict(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_vector_view); + +template CUVS_EXPORT void predict_proba(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view); + +template CUVS_EXPORT void score_samples(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_vector_view); + +} // namespace cuvs::cluster::gmm diff --git a/cpp/src/cluster/gmm_float.cu b/cpp/src/cluster/gmm_float.cu new file mode 100644 index 0000000000..466a370bf5 --- /dev/null +++ b/cpp/src/cluster/gmm_float.cu @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "gmm.cuh" + +#include +#include + +namespace cuvs::cluster::gmm { + +template CUVS_EXPORT void fit(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_vector_view, + raft::device_vector_view, + raft::device_vector_view, + raft::host_scalar_view, + raft::host_scalar_view, + raft::host_scalar_view, + bool); + +template CUVS_EXPORT void predict(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_vector_view); + +template CUVS_EXPORT void predict_proba(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view); + +template CUVS_EXPORT void score_samples(raft::resources const&, + const params&, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_matrix_view, + raft::device_vector_view, + raft::device_vector_view); + +} // namespace cuvs::cluster::gmm diff --git a/cpp/src/cluster/gmm_impl.cuh b/cpp/src/cluster/gmm_impl.cuh new file mode 100644 index 0000000000..63f29ac071 --- /dev/null +++ b/cpp/src/cluster/gmm_impl.cuh @@ -0,0 +1,1610 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "gmm_kernels.cuh" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace cuvs::cluster::gmm::detail { + +constexpr int E_STEP_BLOCK = 64; +constexpr int E_STEP_LARGE64_TILE = 64; +constexpr int E_STEP_THREAD64_BLOCK = 512; +constexpr int NORMALIZE_BLOCK = 32; +constexpr int REDUCE_BLOCK = 256; +constexpr size_t DEFAULT_SMEM_LIMIT = 48 * 1024; + +inline size_t upper_tri_size(size_t d) { return (d * (d + 1)) / 2; } + +inline void cublas_check(cublasStatus_t status, const char* what) +{ + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error(std::string(what) + " failed with cuBLAS status " + + std::to_string(static_cast(status))); + } +} + +inline void cusolver_check(cusolverStatus_t status, const char* what) +{ + if (status != CUSOLVER_STATUS_SUCCESS) { + throw std::runtime_error(std::string(what) + " failed with cuSOLVER status " + + std::to_string(static_cast(status))); + } +} + +inline const char* precision_error_message() +{ + return "Fitting the mixture model failed because some components have ill-defined empirical " + "covariance (for instance caused by singleton or collapsed samples). 
Try to decrease the " + "number of components, increase reg_covar, or scale the input data."; +} + +// ----- cuBLAS gemm / gemv wrappers ----- +template +cublasStatus_t cublas_gemm(cublasHandle_t h, + cublasOperation_t ta, + cublasOperation_t tb, + int m, + int n, + int k, + const T* alpha, + const T* A, + int lda, + const T* B, + int ldb, + const T* beta, + T* C, + int ldc); +template <> +inline cublasStatus_t cublas_gemm(cublasHandle_t h, + cublasOperation_t ta, + cublasOperation_t tb, + int m, + int n, + int k, + const float* alpha, + const float* A, + int lda, + const float* B, + int ldb, + const float* beta, + float* C, + int ldc) +{ + return cublasSgemm(h, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} +template <> +inline cublasStatus_t cublas_gemm(cublasHandle_t h, + cublasOperation_t ta, + cublasOperation_t tb, + int m, + int n, + int k, + const double* alpha, + const double* A, + int lda, + const double* B, + int ldb, + const double* beta, + double* C, + int ldc) +{ + return cublasDgemm(h, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +template +cublasStatus_t cublas_gemv(cublasHandle_t h, + cublasOperation_t trans, + int m, + int n, + const T* alpha, + const T* A, + int lda, + const T* x, + int incx, + const T* beta, + T* y, + int incy); +template <> +inline cublasStatus_t cublas_gemv(cublasHandle_t h, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const float* A, + int lda, + const float* x, + int incx, + const float* beta, + float* y, + int incy) +{ + return cublasSgemv(h, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} +template <> +inline cublasStatus_t cublas_gemv(cublasHandle_t h, + cublasOperation_t trans, + int m, + int n, + const double* alpha, + const double* A, + int lda, + const double* x, + int incx, + const double* beta, + double* y, + int incy) +{ + return cublasDgemv(h, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +// ----- cuSOLVER potrf / cuBLAS trsm wrappers ----- 
+template +cusolverStatus_t potrf_bufsize( + cusolverDnHandle_t h, cublasFillMode_t uplo, int nn, T* A, int lda, int* lwork); +template <> +inline cusolverStatus_t potrf_bufsize( + cusolverDnHandle_t h, cublasFillMode_t uplo, int nn, float* A, int lda, int* lwork) +{ + return cusolverDnSpotrf_bufferSize(h, uplo, nn, A, lda, lwork); +} +template <> +inline cusolverStatus_t potrf_bufsize( + cusolverDnHandle_t h, cublasFillMode_t uplo, int nn, double* A, int lda, int* lwork) +{ + return cusolverDnDpotrf_bufferSize(h, uplo, nn, A, lda, lwork); +} + +template +cusolverStatus_t potrf(cusolverDnHandle_t h, + cublasFillMode_t uplo, + int nn, + T* A, + int lda, + T* work, + int lwork, + int* devInfo); +template <> +inline cusolverStatus_t potrf(cusolverDnHandle_t h, + cublasFillMode_t uplo, + int nn, + float* A, + int lda, + float* work, + int lwork, + int* devInfo) +{ + return cusolverDnSpotrf(h, uplo, nn, A, lda, work, lwork, devInfo); +} +template <> +inline cusolverStatus_t potrf(cusolverDnHandle_t h, + cublasFillMode_t uplo, + int nn, + double* A, + int lda, + double* work, + int lwork, + int* devInfo) +{ + return cusolverDnDpotrf(h, uplo, nn, A, lda, work, lwork, devInfo); +} + +template +cublasStatus_t cublas_trsm(cublasHandle_t h, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const T* alpha, + const T* A, + int lda, + T* B, + int ldb); +template <> +inline cublasStatus_t cublas_trsm(cublasHandle_t h, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const float* alpha, + const float* A, + int lda, + float* B, + int ldb) +{ + return cublasStrsm(h, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} +template <> +inline cublasStatus_t cublas_trsm(cublasHandle_t h, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const double* alpha, + const 
double* A, + int lda, + double* B, + int ldb) +{ + return cublasDtrsm(h, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +template +cusolverStatus_t potrf_batched(cusolverDnHandle_t h, + cublasFillMode_t uplo, + int nn, + T* Aarray[], + int lda, + int* infoArray, + int batch); +template <> +inline cusolverStatus_t potrf_batched(cusolverDnHandle_t h, + cublasFillMode_t uplo, + int nn, + float* Aarray[], + int lda, + int* infoArray, + int batch) +{ + return cusolverDnSpotrfBatched(h, uplo, nn, Aarray, lda, infoArray, batch); +} +template <> +inline cusolverStatus_t potrf_batched(cusolverDnHandle_t h, + cublasFillMode_t uplo, + int nn, + double* Aarray[], + int lda, + int* infoArray, + int batch) +{ + return cusolverDnDpotrfBatched(h, uplo, nn, Aarray, lda, infoArray, batch); +} + +template +cublasStatus_t trsm_batched(cublasHandle_t h, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const T* alpha, + const T* const Aarray[], + int lda, + T* const Barray[], + int ldb, + int batch); +template <> +inline cublasStatus_t trsm_batched(cublasHandle_t h, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const float* alpha, + const float* const Aarray[], + int lda, + float* const Barray[], + int ldb, + int batch) +{ + return cublasStrsmBatched( + h, side, uplo, trans, diag, m, n, alpha, Aarray, lda, Barray, ldb, batch); +} +template <> +inline cublasStatus_t trsm_batched(cublasHandle_t h, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const double* alpha, + const double* const Aarray[], + int lda, + double* const Barray[], + int ldb, + int batch) +{ + return cublasDtrsmBatched( + h, side, uplo, trans, diag, m, n, alpha, Aarray, lda, Barray, ldb, batch); +} + +// Number of elements in the covariance / precision buffer for a covariance type. 
+inline size_t cov_elems(covariance_type ct, int d, int K) +{ + switch (ct) { + case covariance_type::FULL: return (size_t)K * d * d; + case covariance_type::TIED: return (size_t)d * d; + case covariance_type::DIAG: return (size_t)K * d; + case covariance_type::SPHERICAL: return (size_t)K; + } + return 0; +} + +// =========================================================================== +// E-step +// =========================================================================== +template +void launch_small_fixed(const T* X, + const T* weights, + const T* means, + const T* prec_chol, + const T* log_det, + int n, + int K, + int prec_pc, + T* log_prob, + dim3 grid, + dim3 block, + cudaStream_t stream) +{ + size_t shmem = (D + upper_tri_size(D)) * sizeof(T); + if (shmem > DEFAULT_SMEM_LIMIT) { + RAFT_CUDA_TRY(cudaFuncSetAttribute(detail::e_step_log_prob_small_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + (int)shmem)); + } + detail::e_step_log_prob_small_kernel<<>>( + X, weights, means, prec_chol, log_det, n, D, K, prec_pc, log_prob); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void e_step(raft::resources const& handle, + const params& params, + const T* X, + int n, + int d, + int K, + const T* weights, + const T* means, + const T* prec_chol, + const T* log_det, + T* log_prob, + T* resp, + T* log_prob_norm) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + covariance_type ct = params.cov_type; + + if (ct == covariance_type::FULL || (ct == covariance_type::TIED && d > 128)) { + int prec_pc = (ct == covariance_type::FULL) ? 1 : 0; + // Size-based solver selection: the fused + // shared-memory kernels are fastest for the moderate-d regime; wide + // feature counts (and float64 above 64) route through a cuBLAS E-step that + // forms (X - means_k) @ prec_chol_k with a GEMM per component. + bool use_cublas = (sizeof(T) == 4) ? 
(d >= 257) : (d > 64); + if (use_cublas) { + cublasHandle_t cublas = raft::resource::get_cublas_handle(handle); + cublas_check(cublasSetStream(cublas, stream), "cublasSetStream"); + rmm::device_uvector centered((size_t)n * d, stream); + rmm::device_uvector y((size_t)n * d, stream); + T one = T(1), zero = T(0); + int threads = 256; + int center_blocks = (int)(((size_t)n * d + threads - 1) / threads); + int row_blocks = (n + threads - 1) / threads; + for (int k = 0; k < K; ++k) { + detail::e_step_center_kernel + <<>>(X, means, n, d, k, centered.data()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + const T* pc_k = prec_chol + (size_t)(prec_pc ? k : 0) * d * d; + // y = (X - means_k) @ prec_chol_k. cuBLAS is column-major: the row-major + // (n, d) centered buffer is a column-major (d, n) matrix, so this GEMM + // computes the column-major (d, n) result prec_chol_kᵀ_cm @ centered_cm, + // which read back row-major is exactly the (n, d) matrix y[row, j]. + cublas_check(cublas_gemm(cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + d, + n, + d, + &one, + pc_k, + d, + centered.data(), + d, + &zero, + y.data(), + d), + "gemm(e_step_cublas)"); + detail::e_step_log_prob_from_y_kernel + <<>>(y.data(), weights, log_det, n, d, K, k, log_prob); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } else if (d <= 64) { + dim3 block(E_STEP_BLOCK); + dim3 grid((n + E_STEP_BLOCK - 1) / E_STEP_BLOCK, K); + if (d == 16) { + launch_small_fixed( + X, weights, means, prec_chol, log_det, n, K, prec_pc, log_prob, grid, block, stream); + } else if (d == 32) { + launch_small_fixed( + X, weights, means, prec_chol, log_det, n, K, prec_pc, log_prob, grid, block, stream); + } else if (d == 50) { + launch_small_fixed( + X, weights, means, prec_chol, log_det, n, K, prec_pc, log_prob, grid, block, stream); + } else if (d == 64) { + launch_small_fixed( + X, weights, means, prec_chol, log_det, n, K, prec_pc, log_prob, grid, block, stream); + } else { + size_t shmem = ((size_t)d + upper_tri_size(d)) * sizeof(T); + if 
(shmem > DEFAULT_SMEM_LIMIT) { + RAFT_CUDA_TRY(cudaFuncSetAttribute(detail::e_step_log_prob_small_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + (int)shmem)); + } + detail::e_step_log_prob_small_kernel<<>>( + X, weights, means, prec_chol, log_det, n, d, K, prec_pc, log_prob); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } else { + dim3 block(E_STEP_THREAD64_BLOCK); + dim3 grid((n + E_STEP_THREAD64_BLOCK - 1) / E_STEP_THREAD64_BLOCK, K); + size_t shmem = + ((size_t)E_STEP_LARGE64_TILE + (size_t)E_STEP_LARGE64_TILE * E_STEP_LARGE64_TILE) * + sizeof(T); + detail::e_step_log_prob_large_d_thread64_kernel + <<>>( + X, weights, means, prec_chol, log_det, n, d, K, prec_pc, log_prob); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } else if (ct == covariance_type::DIAG || + ct == covariance_type::SPHERICAL) { // fast register-tiled log-prob + // (component-tiled in shared memory, so it scales to large K where the + // per-mode kernels can't fit the means). Writes log_prob[n,k]; normalize below. + rmm::device_uvector const_k(K, stream); + detail::fused_const_kernel + <<>>(weights, log_det, d, K, const_k.data()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + constexpr int FEAT = 32; + constexpr int CELL = (sizeof(T) == 4) ? 64 : 32; + int tpb = 256; + int gb = (n + tpb - 1) / tpb; + if (ct == covariance_type::DIAG) + detail::estep_tiled_kernel + <<>>( + X, means, prec_chol, const_k.data(), n, d, K, nullptr, nullptr, log_prob); + else + detail::estep_tiled_kernel + <<>>( + X, means, prec_chol, const_k.data(), n, d, K, nullptr, nullptr, log_prob); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } else { // TIED, d <= 128: shared Cholesky U -> transform X̃ = X·U, μ̃ = means·U, + // then ‖Uᵀ(x-μ_k)‖² is Euclidean ‖x̃-μ̃_k‖² -> same register-tiled kernel. 
+    cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+    cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+    rmm::device_uvector xt((size_t)n * d, stream);
+    rmm::device_uvector mut((size_t)K * d, stream);
+    rmm::device_uvector ones_pc(K, stream);
+    rmm::device_uvector const_k(K, stream);
+    thrust::fill(thrust::cuda::par.on(stream), ones_pc.data(), ones_pc.data() + K, T(1));
+    T one = T(1), zero = T(0);
+    cublas_check(
+      cublas_gemm(
+        cublas, CUBLAS_OP_N, CUBLAS_OP_N, d, n, d, &one, prec_chol, d, X, d, &zero, xt.data(), d),
+      "gemm(tied_X)");
+    cublas_check(cublas_gemm(cublas,
+                             CUBLAS_OP_N,
+                             CUBLAS_OP_N,
+                             d,
+                             K,
+                             d,
+                             &one,
+                             prec_chol,
+                             d,
+                             means,
+                             d,
+                             &zero,
+                             mut.data(),
+                             d),
+                 "gemm(tied_mu)");
+    detail::fused_const_kernel
+      <<>>(weights, log_det, d, K, const_k.data());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    constexpr int FEAT = 32;
+    constexpr int CELL = (sizeof(T) == 4) ? 64 : 32;
+    int tpb = 256;
+    int gb = (n + tpb - 1) / tpb;
+    detail::estep_tiled_kernel
+      <<>>(
+        xt.data(), mut.data(), ones_pc.data(), const_k.data(), n, d, K, nullptr, nullptr, log_prob);
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+  }
+
+  dim3 nb(NORMALIZE_BLOCK);
+  dim3 ng(n);
+  detail::e_step_normalize_kernel<<>>(log_prob, n, K, resp, log_prob_norm);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
+
+// ===========================================================================
+// M-step (weights, means, covariances) + precision Cholesky + log-det
+// ===========================================================================
+// Scratch buffers shared by the M-step and the precision-Cholesky helpers.
+// Every buffer is created empty except N_k (K), num (K*d) and devInfo (K); the
+// constructor then sizes only the buffers the configured covariance type uses:
+//   FULL      : centered (n*d), cov_work (K*d*d), dA_ptrs (K), dB_ptrs (K)
+//   TIED      : scaled_means (K*d), XtX (d*d), Lbuf (d*d), potrf_work (lwork)
+//   DIAG      : Xsq (n*d), num2 (K*d)
+//   SPHERICAL : Xsq (n*d), num2 (K*d), diag_var (K*d)
+// `ones` always holds n ones. `n` is the largest batch the workspace will see
+// (fit_impl passes its tile size), not necessarily the full sample count.
+template
+struct MStepWorkspace {
+  rmm::device_uvector ones;          // n ones; the vector operand of gemv(N_k)
+  rmm::device_uvector N_k;           // per-component responsibility sums
+  rmm::device_uvector num;           // (K, d) accumulator for Σ r·x
+  rmm::device_uvector num2;          // (K, d) accumulator for Σ r·x² (diag / spherical)
+  rmm::device_uvector centered;      // (n, d) scratch written by m_cov_full_pass (full)
+  rmm::device_uvector Xsq;           // (n, d) element-wise square of the batch (diag / spherical)
+  rmm::device_uvector diag_var;      // (K, d) intermediate diagonal variances (spherical)
+  rmm::device_uvector scaled_means;  // (K, d) output of scale_rows_by_kernel (tied)
+  rmm::device_uvector XtX;           // (d, d) accumulator for Σ x·xᵀ (tied)
+  rmm::device_uvector Lbuf;          // (d, d) copy of the covariance factorized by potrf (tied)
+  rmm::device_uvector potrf_work;    // potrf workspace of lwork elements (tied)
+  rmm::device_uvector devInfo;       // K status ints written by potrf / potrfBatched
+  rmm::device_uvector cov_work;  // batched copy of covariances for potrfBatched
+  rmm::device_uvector dA_ptrs;   // device array of K pointers into cov_work
+  rmm::device_uvector dB_ptrs;   // device array of K pointers into precisions_chol
+  int lwork;                     // potrf workspace size from potrf_bufsize; stays 0 unless TIED
+
+  MStepWorkspace(raft::resources const& handle, const params& params, int n, int d, int K)
+    : ones(0, raft::resource::get_cuda_stream(handle)),
+      N_k(K, raft::resource::get_cuda_stream(handle)),
+      num((size_t)K * d, raft::resource::get_cuda_stream(handle)),
+      num2(0, raft::resource::get_cuda_stream(handle)),
+      centered(0, raft::resource::get_cuda_stream(handle)),
+      Xsq(0, raft::resource::get_cuda_stream(handle)),
+      diag_var(0, raft::resource::get_cuda_stream(handle)),
+      scaled_means(0, raft::resource::get_cuda_stream(handle)),
+      XtX(0, raft::resource::get_cuda_stream(handle)),
+      Lbuf(0, raft::resource::get_cuda_stream(handle)),
+      potrf_work(0, raft::resource::get_cuda_stream(handle)),
+      devInfo(K, raft::resource::get_cuda_stream(handle)),
+      cov_work(0, raft::resource::get_cuda_stream(handle)),
+      dA_ptrs(0, raft::resource::get_cuda_stream(handle)),
+      dB_ptrs(0, raft::resource::get_cuda_stream(handle)),
+      lwork(0)
+  {
+    cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+    covariance_type ct = params.cov_type;
+    ones.resize(n, stream);
+    thrust::fill(thrust::cuda::par.on(stream), ones.data(), ones.data() + n, T(1));
+    if (ct == covariance_type::FULL) {
+      centered.resize((size_t)n * d, stream);
+      // batched precision-Cholesky scratch
+      cov_work.resize((size_t)K * d * d, stream);
+      dA_ptrs.resize(K, stream);
+      dB_ptrs.resize(K, stream);
+    } else if (ct == covariance_type::TIED) {
+      scaled_means.resize((size_t)K * d, stream);
+      XtX.resize((size_t)d * d, stream);
+    } else {  // diag / spherical
+      Xsq.resize((size_t)n * d, stream);
+      num2.resize((size_t)K * d, stream);
+      if (ct == covariance_type::SPHERICAL) diag_var.resize((size_t)K * d, stream);
+    }
+    // Single-matrix potrf scratch: only TIED goes through precision_cholesky_one.
+    if (ct == covariance_type::TIED) {
+      Lbuf.resize((size_t)d * d, stream);
+      int lw = 0;
+      cusolver_check(potrf_bufsize(raft::resource::get_cusolver_dn_handle(handle),
+                                   CUBLAS_FILL_MODE_LOWER,
+                                   d,
+                                   Lbuf.data(),
+                                   d,
+                                   &lw),
+                     "potrf_bufferSize");
+      lwork = lw;
+      potrf_work.resize((size_t)lw, stream);
+    }
+  }
+};
+
+// Accumulate responsibility-weighted sufficient statistics for one batch into the
+// workspace: N_k += Σr, num += Σr·x, plus the second-moment term (num2 += Σr·x²
+// for diag/sph; XtX += Σx·xᵀ for tied). FULL gathers only N_k / num here; its
+// centered second moment is accumulated separately by m_cov_full_pass. beta=0
+// starts fresh, beta=1 adds on — so fit can stream N in tiles, never holding (N,K).
+// `n` is the number of rows in this batch of X / resp, and must not exceed the
+// `n` the workspace was built with (ws.ones and ws.Xsq are sized from it).
+template
+void m_accumulate(raft::resources const& handle,
+                  const params& params,
+                  const T* X,
+                  int n,
+                  int d,
+                  int K,
+                  const T* resp,
+                  MStepWorkspace& ws,
+                  T beta)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+  cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+  covariance_type ct = params.cov_type;
+  T one = T(1);
+
+  // N_k = resp-as-(K x n, ld K) times the ones vector. NOTE(review): presumably
+  // resp is a row-major (n, K) block that cuBLAS reads as its transpose — confirm.
+  cublas_check(
+    cublas_gemv(
+      cublas, CUBLAS_OP_N, K, n, &one, resp, K, ws.ones.data(), 1, &beta, ws.N_k.data(), 1),
+    "gemv(N_k)");
+  // num (d x K, ld d) = X (d x n, ld d) times resp (K x n, ld K) transposed.
+  cublas_check(
+    cublas_gemm(
+      cublas, CUBLAS_OP_N, CUBLAS_OP_T, d, K, n, &one, X, d, resp, K, &beta, ws.num.data(), d),
+    "gemm(num)");
+
+  // FULL accumulates its centered covariance in a separate pass (needs means);
+  // here only N_k / num are gathered for it.
+  if (ct == covariance_type::TIED) {
+    // Unweighted second moment: resp does not enter this product.
+    cublas_check(
+      cublas_gemm(
+        cublas, CUBLAS_OP_N, CUBLAS_OP_T, d, d, n, &one, X, d, X, d, &beta, ws.XtX.data(), d),
+      "gemm(XtX)");
+  } else if (ct == covariance_type::DIAG || ct == covariance_type::SPHERICAL) {
+    int threads = 256;
+    int blocks = (int)(((size_t)n * d + threads - 1) / threads);
+    // Xsq is rewritten for every batch; only num2 carries state across batches.
+    detail::elementwise_square_kernel
+      <<>>(X, (size_t)n * d, ws.Xsq.data());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    cublas_check(cublas_gemm(cublas,
+                             CUBLAS_OP_N,
+                             CUBLAS_OP_T,
+                             d,
+                             K,
+                             n,
+                             &one,
+                             ws.Xsq.data(),
+                             d,
+                             resp,
+                             K,
+                             &beta,
+                             ws.num2.data(),
+                             d),
+                 "gemm(num2)");
+  }
+}
+
+// Turn accumulated stats into weights / means / covariances (uncentered recovery).
+// Always writes weights and means. `covariances` is written for TIED, DIAG and
+// SPHERICAL; FULL leaves it untouched (see m_cov_full_pass / m_finalize_cov_full).
+// `n` is forwarded to the means kernel; fit_impl passes the full sample count.
+// The TIED branch copies N_k to the host and synchronizes the stream.
+template
+void m_finalize(raft::resources const& handle,
+                const params& params,
+                int n,
+                int d,
+                int K,
+                MStepWorkspace& ws,
+                T* weights,
+                T* means,
+                T* covariances)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+  cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+  covariance_type ct = params.cov_type;
+  T one = T(1), zero = T(0);
+  T eps = std::numeric_limits::epsilon();
+
+  detail::m_step_finalize_means_kernel
+    <<>>(ws.N_k.data(), ws.num.data(), weights, means, eps, n, d, K);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+
+  // FULL: covariance is finalized after its separate centered pass; nothing here.
+  if (ct == covariance_type::TIED) {
+    rmm::device_uvector weighted_outer((size_t)d * d, stream);
+    // NOTE(review): presumably scaled_means[k] = N_k[k] * means[k], making
+    // weighted_outer = Σ_k N_k·μ_k·μ_kᵀ — kernel body not visible here; confirm.
+    detail::scale_rows_by_kernel
+      <<>>(means, ws.N_k.data(), eps, d, K, ws.scaled_means.data());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    cublas_check(cublas_gemm(cublas,
+                             CUBLAS_OP_N,
+                             CUBLAS_OP_T,
+                             d,
+                             d,
+                             K,
+                             &one,
+                             ws.scaled_means.data(),
+                             d,
+                             means,
+                             d,
+                             &zero,
+                             weighted_outer.data(),
+                             d),
+                 "gemm(weighted_outer)");
+    // The tied kernel takes one scalar normalizer, so N_k is summed on the host
+    // (in double, with 10*eps added per component); this costs a stream sync.
+    std::vector h_Nk(K);
+    raft::copy(h_Nk.data(), ws.N_k.data(), K, stream);
+    raft::resource::sync_stream(handle);
+    double sum_nk = 0.0;
+    for (int k = 0; k < K; ++k)
+      sum_nk += (double)h_Nk[k] + 10.0 * (double)eps;
+    int total = d * d, threads = 256, blocks = (total + threads - 1) / threads;
+    detail::m_step_finalize_tied_kernel<<>>(
+      ws.XtX.data(), weighted_outer.data(), T(sum_nk), T(params.reg_covar), d, covariances);
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+  } else if (ct == covariance_type::DIAG || ct == covariance_type::SPHERICAL) {
+    if (ct == covariance_type::DIAG) {
+      detail::m_step_finalize_diag_kernel<<>>(
+        ws.N_k.data(), ws.num2.data(), means, T(params.reg_covar), eps, d, K, covariances);
+    } else {
+      // Spherical reuses the diag kernel into ws.diag_var, then a second kernel
+      // derives the K output values from those (K, d) variances.
+      detail::m_step_finalize_diag_kernel<<>>(
+        ws.N_k.data(), ws.num2.data(), means, T(params.reg_covar), eps, d, K, ws.diag_var.data());
+      RAFT_CUDA_TRY(cudaPeekAtLastError());
+      detail::m_step_spherical_from_diag_kernel
+        <<>>(ws.diag_var.data(), d, K, covariances);
+    }
+    RAFT_CUDA_TRY(cudaPeekAtLastError());  // checks the last launch of either branch
+  }
+}
+
+// FULL covariance, pass 2: accumulate the centered moment Σ r·(x-mu)(x-mu)ᵀ for
+// one batch into covariances (beta=0 fresh, beta=1 add). Centered (not Σr·x·xᵀ)
+// to avoid the catastrophic cancellation of the uncentered form on data far from
+// the origin. Requires the means from pass 1.
+// Issues one weighted_center_kernel launch and one d×d GEMM per component.
+// ws.centered is rewritten for every k, so it must hold at least n*d elements
+// (MStepWorkspace sizes it only for covariance_type::FULL). The same `beta` is
+// applied to each of the K output slices covariances[k*d*d .. (k+1)*d*d).
+template
+void m_cov_full_pass(raft::resources const& handle,
+                     const T* X,
+                     int n,
+                     int d,
+                     int K,
+                     const T* resp,
+                     const T* means,
+                     T* covariances,
+                     MStepWorkspace& ws,
+                     T beta)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+  cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+  T one = T(1);
+  int threads = 256;
+  int blocks = (int)(((size_t)n * d + threads - 1) / threads);  // one thread per element of X
+  for (int k = 0; k < K; ++k) {
+    // NOTE(review): presumably writes sqrt(r_ik)·(x_i - mu_k) so that the GEMM
+    // below yields Σ_i r_ik·(x_i - mu_k)(x_i - mu_k)ᵀ — kernel not visible; confirm.
+    detail::weighted_center_kernel
+      <<>>(X, resp, means, n, d, K, k, ws.centered.data());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    // Both GEMM operands are ws.centered: covariances[k] = centered·centeredᵀ
+    // + beta·covariances[k], with the n samples as the contracted dimension.
+    cublas_check(cublas_gemm(cublas,
+                             CUBLAS_OP_N,
+                             CUBLAS_OP_T,
+                             d,
+                             d,
+                             n,
+                             &one,
+                             ws.centered.data(),
+                             d,
+                             ws.centered.data(),
+                             d,
+                             &beta,
+                             covariances + (size_t)k * d * d,
+                             d),
+                 "gemm(cov_full)");
+  }
+}
+
+// Finalize FULL covariances (divide by N_k, symmetrize, add reg_covar).
+// Operates in place on `covariances`, using the N_k accumulated in `ws` and
+// params.reg_covar; a single kernel launch handles all K components.
+template
+void m_finalize_cov_full(raft::resources const& handle,
+                         const params& params,
+                         int d,
+                         int K,
+                         MStepWorkspace& ws,
+                         T* covariances)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  T eps = std::numeric_limits::epsilon();
+  detail::m_step_finalize_cov_full_kernel
+    <<>>(ws.N_k.data(), covariances, T(params.reg_covar), eps, d, K);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
+
+// One d×d precision-Cholesky: potrf(LOWER) then trsm solving L X = I; the
+// resulting column-major L^{-1} is, read row-major, the upper precision factor.
+// `cov` is first copied into ws.Lbuf, so the factorization never modifies the
+// caller's covariance; the result is written to `prec_chol` (d*d elements).
+// Throws std::runtime_error(precision_error_message()) when potrf reports a
+// non-zero info. Relies on ws.Lbuf / ws.potrf_work / ws.lwork, which
+// MStepWorkspace sizes only for covariance_type::TIED.
+template
+void precision_cholesky_one(
+  raft::resources const& handle, const T* cov, T* prec_chol, int d, MStepWorkspace& ws)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  cusolverDnHandle_t solver = raft::resource::get_cusolver_dn_handle(handle);
+  cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+  cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+  cusolver_check(cusolverDnSetStream(solver, stream), "cusolverDnSetStream");
+
+  raft::copy(ws.Lbuf.data(), cov, (size_t)d * d, stream);  // factorize a scratch copy
+  cusolver_check(potrf(solver,
+                       CUBLAS_FILL_MODE_LOWER,
+                       d,
+                       ws.Lbuf.data(),
+                       d,
+                       ws.potrf_work.data(),
+                       ws.lwork,
+                       ws.devInfo.data()),
+                 "potrf");
+  // Only devInfo[0] is consulted here; the host must wait for it before the
+  // factor can be trusted, hence the stream sync.
+  int info = 0;
+  raft::copy(&info, ws.devInfo.data(), 1, stream);
+  raft::resource::sync_stream(handle);
+  if (info != 0) throw std::runtime_error(precision_error_message());
+
+  // trsm overwrites its right-hand side, so prec_chol is seeded with the
+  // identity and ends up holding the solution of L X = I.
+  detail::set_identity_kernel<<>>(prec_chol, d);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+  T one = T(1);
+  cublas_check(cublas_trsm(cublas,
+                           CUBLAS_SIDE_LEFT,
+                           CUBLAS_FILL_MODE_LOWER,
+                           CUBLAS_OP_N,
+                           CUBLAS_DIAG_NON_UNIT,
+                           d,
+                           d,
+                           &one,
+                           ws.Lbuf.data(),
+                           d,
+                           prec_chol,
+                           d),
+               "trsm(precision_cholesky)");
+}
+
+// Batched precision-Cholesky for full covariance: one potrfBatched + one
+// trsmBatched over all K components (single info sync), matching the batched
+// approach a cupy/array-API solver uses, instead of K per-component calls each
+// with a host sync.
+// The factorization runs on a copy (ws.cov_work), leaving `covariances`
+// untouched; the K results are written to prec_chol (K*d*d elements).
+// Throws std::runtime_error(precision_error_message()) if any component's
+// potrf info is non-zero. The info check happens after trsmBatched has been
+// queued, so prec_chol may already hold unusable values when it throws.
+// Uses ws.cov_work / ws.dA_ptrs / ws.dB_ptrs (sized by MStepWorkspace for
+// covariance_type::FULL) and ws.devInfo (K entries).
+template
+void precision_cholesky_full_batched(raft::resources const& handle,
+                                     const T* covariances,
+                                     T* prec_chol,
+                                     int d,
+                                     int K,
+                                     MStepWorkspace& ws)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  cusolverDnHandle_t solver = raft::resource::get_cusolver_dn_handle(handle);
+  cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+  cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+  cusolver_check(cusolverDnSetStream(solver, stream), "cusolverDnSetStream");
+
+  raft::copy(ws.cov_work.data(), covariances, (size_t)K * d * d, stream);
+  // The batched routines take device arrays of per-matrix pointers; build them
+  // on the host (A -> scratch copy, B -> output) and upload on the same stream.
+  std::vector hA(K), hB(K);
+  for (int k = 0; k < K; ++k) {
+    hA[k] = ws.cov_work.data() + (size_t)k * d * d;
+    hB[k] = prec_chol + (size_t)k * d * d;
+  }
+  RAFT_CUDA_TRY(
+    cudaMemcpyAsync(ws.dA_ptrs.data(), hA.data(), sizeof(T*) * K, cudaMemcpyHostToDevice, stream));
+  RAFT_CUDA_TRY(
+    cudaMemcpyAsync(ws.dB_ptrs.data(), hB.data(), sizeof(T*) * K, cudaMemcpyHostToDevice, stream));
+
+  cusolver_check(
+    potrf_batched(solver, CUBLAS_FILL_MODE_LOWER, d, ws.dA_ptrs.data(), d, ws.devInfo.data(), K),
+    "potrfBatched");
+  // trsm overwrites its right-hand side: seed every output block before solving.
+  detail::set_identity_batched_kernel
+    <<>>(prec_chol, d, K);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+  T one = T(1);
+  cublas_check(trsm_batched(cublas,
+                            CUBLAS_SIDE_LEFT,
+                            CUBLAS_FILL_MODE_LOWER,
+                            CUBLAS_OP_N,
+                            CUBLAS_DIAG_NON_UNIT,
+                            d,
+                            d,
+                            &one,
+                            ws.dA_ptrs.data(),
+                            d,
+                            ws.dB_ptrs.data(),
+                            d,
+                            K),
+               "trsmBatched");
+
+  // Single host round-trip for all K status codes.
+  std::vector hinfo(K);
+  raft::copy(hinfo.data(), ws.devInfo.data(), K, stream);
+  raft::resource::sync_stream(handle);
+  for (int k = 0; k < K; ++k)
+    if (hinfo[k] != 0) throw std::runtime_error(precision_error_message());
+}
+
+// covariances -> precisions_chol and log_det, dispatched on covariance type.
+//   FULL             : batched Cholesky over all K matrices, then log_det_full_kernel.
+//   TIED             : one Cholesky of the shared matrix, then log_det_tied_kernel.
+//   DIAG / SPHERICAL : element-wise recip_sqrt_kernel, after a host-side check
+//                      that the smallest variance is strictly positive.
+// Writes prec_chol and log_det (K entries are requested from every log_det
+// kernel). Throws std::runtime_error(precision_error_message()) on a failed
+// factorization or a non-positive variance. Every path blocks the host at
+// least once (info copy + sync, or the thrust::reduce result).
+template
+void update_precisions(raft::resources const& handle,
+                       const params& params,
+                       const T* covariances,
+                       int d,
+                       int K,
+                       T* prec_chol,
+                       T* log_det,
+                       MStepWorkspace& ws)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  covariance_type ct = params.cov_type;
+  if (ct == covariance_type::FULL) {
+    precision_cholesky_full_batched(handle, covariances, prec_chol, d, K, ws);
+    detail::log_det_full_kernel
+      <<>>(prec_chol, d, K, log_det);
+  } else if (ct == covariance_type::TIED) {
+    precision_cholesky_one(handle, covariances, prec_chol, d, ws);
+    detail::log_det_tied_kernel
+      <<>>(prec_chol, d, K, log_det);
+  } else {
+    size_t total = cov_elems(ct, d, K);  // number of stored variances for this type
+    // NOTE(review): a NaN variance compares false against `<= 0`, so whether it
+    // is rejected depends on how thrust::minimum orders NaNs — confirm.
+    T min_var = thrust::reduce(thrust::cuda::par.on(stream),
+                               covariances,
+                               covariances + total,
+                               std::numeric_limits::max(),
+                               thrust::minimum());
+    if (min_var <= T(0)) {
+      // a non-positive variance means the covariance is ill-defined
+      throw std::runtime_error(precision_error_message());
+    }
+    detail::recip_sqrt_kernel
+      <<>>(covariances, total, prec_chol);
+    if (ct == covariance_type::DIAG) {
+      detail::log_det_diag_kernel
+        <<>>(prec_chol, d, K, log_det);
+    } else {
+      detail::log_det_spherical_kernel
+        <<>>(prec_chol, d, K, log_det);
+    }
+  }
+  // One check for whichever kernel was launched last on the taken branch.
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
+
+// precisions_ = prec_chol @ prec_cholᵀ (full/tied) or prec_chol^2 (diag/spherical).
+template
+void compute_precisions(raft::resources const& handle,
+                        const params& params,
+                        const T* prec_chol,
+                        int d,
+                        int K,
+                        T* precisions)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+  cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+  covariance_type ct = params.cov_type;
+  T one = T(1), zero = T(0);
+  if (ct == covariance_type::FULL) {
+    // One d×d GEMM per component; each reads and writes its own d*d slice.
+    for (int k = 0; k < K; ++k) {
+      const T* U = prec_chol + (size_t)k * d * d;  // row-major upper U
+      T* P = precisions + (size_t)k * d * d;
+      // P = U @ Uᵀ (row-major). In column-major: U_cm = Uᵀ (upper->lower). Use
+      // gemm(OP_T, OP_N, d, d, d, U, U) which yields U_cmᵀ @ U_cm = U @ Uᵀ.
+      cublas_check(
+        cublas_gemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, d, d, d, &one, U, d, U, d, &zero, P, d),
+        "gemm(precisions_full)");
+    }
+  } else if (ct == covariance_type::TIED) {
+    // Same product as the FULL case, applied once to the single shared factor.
+    cublas_check(cublas_gemm(cublas,
+                             CUBLAS_OP_T,
+                             CUBLAS_OP_N,
+                             d,
+                             d,
+                             d,
+                             &one,
+                             prec_chol,
+                             d,
+                             prec_chol,
+                             d,
+                             &zero,
+                             precisions,
+                             d),
+                 "gemm(precisions_tied)");
+  } else {
+    // diag / spherical: the factor is stored element-wise, so squaring suffices.
+    size_t total = cov_elems(ct, d, K);
+    detail::elementwise_square_kernel
+      <<>>(prec_chol, total, precisions);
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+  }
+}
+
+// k-means (or k-means++ seeding) hard labels for GMM initialization. Centroids
+// are computed internally; only the per-sample labels (length n) are returned.
+// The k-means settings are fixed here (tol 1e-4, n_init 1, at most 300
+// iterations, k-means++ seeding, Philox RNG seeded with `init_seed`); of the
+// GMM params only `params.init` is consulted.
+template
+void kmeans_assign(raft::resources const& handle,
+                   const params& params,
+                   const T* X,
+                   int n,
+                   int d,
+                   int K,
+                   uint64_t init_seed,
+                   int* labels_out)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  cuvs::cluster::kmeans::params kp;
+  kp.n_clusters = K;
+  // KMeansPlusPlus: seeding labels only (max_iter=0, no Lloyd); KMeans: full Lloyd.
+  kp.max_iter = (params.init == init_method::KMeansPlusPlus) ? 0 : 300;
+  kp.tol = 1e-4;
+  kp.n_init = 1;
+  kp.init = cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus;
+  kp.rng_state = raft::random::RngState{init_seed, raft::random::GeneratorType::GenPhilox};
+  rmm::device_uvector centroids((size_t)K * d, stream);  // scratch; discarded on return
+  T inertia = T(0);  // required by the k-means API; its value is not used here
+  int km_iter = 0;   // likewise unused after the call
+  auto X_view = raft::make_device_matrix_view(X, n, d);
+  cuvs::cluster::kmeans::fit(handle,
+                             kp,
+                             X_view,
+                             std::nullopt,
+                             raft::make_device_matrix_view(centroids.data(), K, d),
+                             raft::make_host_scalar_view(&inertia),
+                             raft::make_host_scalar_view(&km_iter));
+  // A separate predict call turns the fitted centroids into per-sample labels.
+  cuvs::cluster::kmeans::predict(
+    handle,
+    kp,
+    X_view,
+    std::nullopt,
+    raft::make_device_matrix_view(centroids.data(), K, d),
+    raft::make_device_vector_view(labels_out, n),
+    true,
+    raft::make_host_scalar_view(&inertia));
+}
+
+// ===========================================================================
+// EM driver
+// ===========================================================================
+// Arithmetic mean of the n device values at `v`: a thrust::reduce on the
+// handle's stream (the sum comes back to the host), then a host-side divide.
+// The accumulation is done in T. n == 0 divides by zero.
+template
+T mean_device(raft::resources const& handle, const T* v, int n)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  T sum = thrust::reduce(thrust::cuda::par.on(stream), v, v + n, T(0), thrust::plus());
+  return sum / T(n);
+}
+
+// Fit the mixture by EM and leave the best restart's parameters in the outputs.
+//   X (n rows of d features)  : input samples, processed in tiles of <= 65536 rows.
+//   weights / means / covariances / precisions_chol : parameter buffers. With
+//     warm_start they supply the starting point (covariances feeds the first
+//     update_precisions); otherwise every restart overwrites them from a fresh
+//     initialization. On return they hold the best restart's values.
+//   precisions, labels        : derived from the best parameters at the end.
+//   lower_bound / n_iter / converged : lower bound, EM iteration count and
+//     convergence flag of the best restart.
+// Runs params.n_init restarts (exactly one with warm_start), each doing up to
+// params.max_iter EM iterations, and keeps the restart with the largest lower
+// bound. Exceptions from update_precisions propagate and abort the whole fit.
+// NOTE(review): if params.n_init <= 0 without warm_start the restart loop never
+// runs, so the "restore best" copies read uninitialized best_* buffers; n == 0
+// makes em_step divide by zero. Presumably the caller validates both — confirm.
+template
+void fit_impl(raft::resources const& handle,
+              const params& params,
+              const T* X,
+              int n,
+              int d,
+              T* weights,
+              T* means,
+              T* covariances,
+              T* precisions_chol,
+              T* precisions,
+              int* labels,
+              T& lower_bound,
+              int& n_iter,
+              bool& converged,
+              bool warm_start)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  int K = params.n_components;
+  covariance_type ct = params.cov_type;
+  size_t cn = cov_elems(ct, d, K);  // element count of covariances / precisions_chol
+
+  // Stream the E+M over N-tiles: only a bounded (tile, K) responsibility block
+  // ever exists, never the full (N, K).
+  int tile = std::min(n, 65536);
+
+  MStepWorkspace ws(handle, params, tile, d, K);
+  // normalize overwrites log_prob with resp in place (each element is read once and
+  // written once in the same expression), so one (tile, K) buffer serves both.
+  rmm::device_uvector resp((size_t)tile * K, stream);
+  rmm::device_uvector lpn(tile, stream);
+  rmm::device_uvector log_det(K, stream);
+
+  // Device-side copies of the best restart seen so far.
+  rmm::device_uvector best_w(K, stream);
+  rmm::device_uvector best_m((size_t)K * d, stream);
+  rmm::device_uvector best_cov(cn, stream);
+  rmm::device_uvector best_pc(cn, stream);
+
+  raft::random::RngState rng(params.seed, raft::random::GeneratorType::GenPhilox);
+
+  // FULL's M-step runs the E-step twice (first to accumulate weights/means, then
+  // a second centered pass for the covariances). m_finalize overwrites
+  // weights/means between the two passes, so the covariance pass must re-run the
+  // E-step with the *pre-update* params to reproduce the same responsibilities
+  // the means/weights were built from; snapshot them here. Sized 0 (unused) for
+  // the other covariance types, which need no second pass.
+  rmm::device_uvector estep_w((ct == covariance_type::FULL) ? K : 0, stream);
+  rmm::device_uvector estep_m((ct == covariance_type::FULL) ? (size_t)K * d : 0, stream);
+
+  // One E-step (+ optional M-step) over all data, streamed in tiles; returns the
+  // mean lower bound. The E-step weights/means are passed in so FULL's covariance
+  // pass can feed the snapshot above and reproduce the first pass's resp exactly.
+  // estep_tile passes resp.data() as both e_step buffers (log_prob and resp);
+  // precisions_chol and log_det are always the current ones.
+  auto estep_tile = [&](int t0, int nt, const T* e_weights, const T* e_means) {
+    e_step(handle,
+           params,
+           X + (size_t)t0 * d,
+           nt,
+           d,
+           K,
+           e_weights,
+           e_means,
+           precisions_chol,
+           log_det.data(),
+           resp.data(),
+           resp.data(),
+           lpn.data());
+  };
+  auto em_step = [&](bool do_mstep) -> double {
+    double lb_sum = 0.0;
+    for (int t0 = 0; t0 < n; t0 += tile) {
+      int nt = std::min(tile, n - t0);
+      estep_tile(t0, nt, weights, means);
+      if (do_mstep)
+        m_accumulate(
+          handle, params, X + (size_t)t0 * d, nt, d, K, resp.data(), ws, (t0 == 0) ? T(0) : T(1));
+      // Re-weight each tile's mean by its row count so the final division by n
+      // gives the mean over all samples.
+      lb_sum += (double)mean_device(handle, lpn.data(), nt) * nt;
+    }
+    if (do_mstep) {
+      if (ct == covariance_type::FULL) {
+        // Snapshot the params the E-step above used; m_finalize is about to
+        // overwrite weights/means and the covariance pass must reproduce this resp.
+        raft::copy(estep_w.data(), weights, K, stream);
+        raft::copy(estep_m.data(), means, (size_t)K * d, stream);
+      }
+      m_finalize(handle, params, n, d, K, ws, weights, means, covariances);
+      if (ct == covariance_type::FULL) {
+        // Second sweep: old params for the E-step, freshly updated means for
+        // the centering.
+        for (int t0 = 0; t0 < n; t0 += tile) {
+          int nt = std::min(tile, n - t0);
+          estep_tile(t0, nt, estep_w.data(), estep_m.data());
+          m_cov_full_pass(handle,
+                          X + (size_t)t0 * d,
+                          nt,
+                          d,
+                          K,
+                          resp.data(),
+                          means,
+                          covariances,
+                          ws,
+                          (t0 == 0) ? T(0) : T(1));
+        }
+        m_finalize_cov_full(handle, params, d, K, ws, covariances);
+      }
+    }
+    return lb_sum / n;
+  };
+
+  bool do_init = !warm_start;
+  int n_init = do_init ? params.n_init : 1;  // a warm start is a single run
+  double best_lb = -std::numeric_limits::infinity();
+  int best_iter = 0;
+  bool best_conv = false;
+  bool have_best = false;
+
+  for (int init = 0; init < n_init; ++init) {
+    if (do_init) {
+      // Distinct seed per restart, derived from params.seed.
+      uint64_t init_seed = params.seed + (uint64_t)init * 0x9E3779B97F4A7C15ULL;
+      init_method im = params.init;
+      // Per-sample assignment is O(n)/O(K), not (N, K): k-means labels, or the K
+      // random-from-data indices. The one-hot resp is then built one tile at a time.
+      rmm::device_uvector km_labels(0, stream);
+      rmm::device_uvector rfd_idx(0, stream);
+      if (im == init_method::KMeans || im == init_method::KMeansPlusPlus) {
+        km_labels.resize(n, stream);
+        kmeans_assign(handle, params, X, n, d, K, init_seed, km_labels.data());
+      } else if (im == init_method::RandomFromData) {
+        // Draws from the shared `rng`, so successive restarts pick new samples.
+        rfd_idx.resize(K, stream);
+        auto idx = raft::random::excess_subsample(handle, rng, n, K);
+        raft::copy(rfd_idx.data(), idx.data_handle(), K, stream);
+      }
+      // Writes the initial responsibilities for rows [t0, t0 + nt) into resp.
+      auto fill_init_resp = [&](int t0, int nt) {
+        if (im == init_method::Random) {
+          // Deterministic per-(init, tile): FULL fills the responsibilities twice
+          // (means pass, then centered-covariance pass) and both must match. A
+          // shared rng would advance between them and desync cov from the means.
+          raft::random::RngState tile_rng(
+            init_seed ^ (static_cast(t0) * 0x9E3779B97F4A7C15ULL),
+            raft::random::GeneratorType::GenPhilox);
+          raft::random::uniform(handle, tile_rng, resp.data(), (size_t)nt * K, T(0), T(1));
+          detail::normalize_rows_kernel
+            <<>>(resp.data(), nt, K);
+        } else if (im == init_method::RandomFromData) {
+          // Zero the tile, then let the kernel set the entries for the chosen samples.
+          RAFT_CUDA_TRY(cudaMemsetAsync(resp.data(), 0, sizeof(T) * (size_t)nt * K, stream));
+          detail::scatter_onehot_tile_kernel<<>>(
+            rfd_idx.data(), K, t0, nt, resp.data());
+        } else {
+          detail::labels_to_onehot_kernel<<>>(
+            km_labels.data() + t0, nt, K, resp.data());
+        }
+        RAFT_CUDA_TRY(cudaPeekAtLastError());
+      };
+      // Initial M-step from the initial responsibilities (same two-pass shape
+      // as em_step, with fill_init_resp standing in for the E-step).
+      for (int t0 = 0; t0 < n; t0 += tile) {
+        int nt = std::min(tile, n - t0);
+        fill_init_resp(t0, nt);
+        m_accumulate(
+          handle, params, X + (size_t)t0 * d, nt, d, K, resp.data(), ws, (t0 == 0) ? T(0) : T(1));
+      }
+      m_finalize(handle, params, n, d, K, ws, weights, means, covariances);
+      if (ct == covariance_type::FULL) {
+        for (int t0 = 0; t0 < n; t0 += tile) {
+          int nt = std::min(tile, n - t0);
+          fill_init_resp(t0, nt);
+          m_cov_full_pass(handle,
+                          X + (size_t)t0 * d,
+                          nt,
+                          d,
+                          K,
+                          resp.data(),
+                          means,
+                          covariances,
+                          ws,
+                          (t0 == 0) ? T(0) : T(1));
+        }
+        m_finalize_cov_full(handle, params, d, K, ws, covariances);
+      }
+      update_precisions(handle, params, covariances, d, K, precisions_chol, log_det.data(), ws);
+    } else {
+      // Warm start: derive precisions_chol / log_det from the caller's covariances.
+      update_precisions(handle, params, covariances, d, K, precisions_chol, log_det.data(), ws);
+    }
+
+    double lb = -std::numeric_limits::infinity();
+    bool conv = false;
+    int iters = 0;
+    for (int it = 1; it <= params.max_iter; ++it) {
+      double prev = lb;
+      lb = em_step(/* do_mstep */ true);
+      update_precisions(handle, params, covariances, d, K, precisions_chol, log_det.data(), ws);
+      iters = it;
+      // `prev` is -inf on the first iteration, so this cannot trigger before it == 2.
+      if (std::abs(lb - prev) < params.tol) {
+        conv = true;
+        break;
+      }
+    }
+    // With no EM iterations, still report the lower bound of the initial params.
+    if (params.max_iter == 0) lb = em_step(/* do_mstep */ false);
+
+    // The first restart is always kept, even if its lower bound is not finite.
+    if (!have_best || lb > best_lb) {
+      best_lb = lb;
+      best_iter = iters;
+      best_conv = conv;
+      have_best = true;
+      raft::copy(best_w.data(), weights, K, stream);
+      raft::copy(best_m.data(), means, (size_t)K * d, stream);
+      raft::copy(best_cov.data(), covariances, cn, stream);
+      raft::copy(best_pc.data(), precisions_chol, cn, stream);
+    }
+  }
+
+  // restore best parameters
+  raft::copy(weights, best_w.data(), K, stream);
+  raft::copy(means, best_m.data(), (size_t)K * d, stream);
+  raft::copy(covariances, best_cov.data(), cn, stream);
+  raft::copy(precisions_chol, best_pc.data(), cn, stream);
+
+  // recompute log_det for the best precisions_chol, derive precisions_, labels
+  if (ct == covariance_type::FULL)
+    detail::log_det_full_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else if (ct == covariance_type::TIED)
+    detail::log_det_tied_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else if (ct == covariance_type::DIAG)
+    detail::log_det_diag_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else
+    detail::log_det_spherical_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+
+  compute_precisions(handle, params, precisions_chol, d, K, precisions);
+
+  // Hard labels = argmax responsibility, computed tile-by-tile (no full (N, K)).
+  for (int t0 = 0; t0 < n; t0 += tile) {
+    int nt = std::min(tile, n - t0);
+    e_step(handle,
+           params,
+           X + (size_t)t0 * d,
+           nt,
+           d,
+           K,
+           weights,
+           means,
+           precisions_chol,
+           log_det.data(),
+           resp.data(),
+           resp.data(),
+           lpn.data());
+    detail::argmax_kernel
+      <<>>(resp.data(), nt, K, labels + t0);
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+  }
+
+  lower_bound = (T)best_lb;  // narrowed from the double used during the search
+  n_iter = best_iter;
+  converged = best_conv;
+  raft::resource::sync_stream(handle);  // outputs are complete when this returns
+}
+
+// ===========================================================================
+// Predict-family helpers (compute log_det from precisions_chol, run E-step)
+// ===========================================================================
+// Standard (non-fused) E-step over all n samples in one call: recomputes
+// log_det from precisions_chol, then e_step fills `resp` and `log_prob_norm`.
+// Allocates an n*K log_prob scratch buffer, so memory grows with n*K. Does not
+// synchronize the stream; callers do that.
+template
+void infer(raft::resources const& handle,
+           const params& params,
+           const T* X,
+           int n,
+           int d,
+           const T* weights,
+           const T* means,
+           const T* precisions_chol,
+           T* resp,
+           T* log_prob_norm)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  int K = params.n_components;
+  covariance_type ct = params.cov_type;
+
+  rmm::device_uvector log_prob((size_t)n * K, stream);
+  rmm::device_uvector log_det(K, stream);
+
+  if (ct == covariance_type::FULL)
+    detail::log_det_full_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else if (ct == covariance_type::TIED)
+    detail::log_det_tied_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else if (ct == covariance_type::DIAG)
+    detail::log_det_diag_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else
+    detail::log_det_spherical_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+
+  // Unlike fit_impl, log_prob and resp are distinct buffers here.
+  e_step(handle,
+         params,
+         X,
+         n,
+         d,
+         K,
+         weights,
+         means,
+         precisions_chol,
+         log_det.data(),
+         log_prob.data(),
+         resp,
+         log_prob_norm);
+}
+
+// Launch the register-tiled fused E-step for spherical (DIAG=false) or diag
+// (DIAG=true).
Components are tiled into CELL register accumulators; features
+// stream through a small shared-memory tile -> high occupancy, any d, and the
+// (N, K) distance matrix is never materialized. CELL=64 (float) / 32 (double)
+// with FEAT=32 is the occupancy sweet spot on Blackwell.
+// All pointer arguments are forwarded to estep_tiled_kernel unchanged;
+// `log_prob` defaults to nullptr, and fused_score also forwards a null `labels`
+// when called from score_samples_impl. One launch with TPB = 256 and
+// ceil(n / TPB) blocks; the stream is not synchronized here.
+template
+void launch_estep_tiled(raft::resources const& handle,
+                        const T* X,
+                        const T* means,
+                        const T* prec_chol,
+                        const T* const_k,
+                        int n,
+                        int d,
+                        int K,
+                        T* log_prob_norm,
+                        int* labels,
+                        T* log_prob = nullptr)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  constexpr int FEAT = 32;
+  constexpr int CELL = (sizeof(T) == 4) ? 64 : 32;  // 64 for 4-byte T, otherwise 32
+  constexpr int TPB = 256;
+  int gb = (n + TPB - 1) / TPB;  // ceil(n / TPB)
+  detail::estep_tiled_kernel<<>>(
+    X, means, prec_chol, const_k, n, d, K, log_prob_norm, labels, log_prob);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
+
+// Fused E-step that writes log_prob_norm (and optionally argmax labels) WITHOUT
+// materializing the (N, K) log_prob / responsibility matrix. Mirrors infer()'s
+// math for every covariance type, via a per-sample online log-sum-exp:
+//   - diag / spherical: fast tiled kernel (means in shared, x in registers),
+//   - tied (d <= 128):  transform X̃ = X·U once (shared Cholesky), then it is a
+//                       spherical problem ‖x̃ - μ̃_k‖² -> reuse the tiled kernel,
+//   - full, or tied with d > 128: one GEMM per component projects X
+//                       (y = M_k @ X, no explicit centering); y is then folded
+//                       online against the projected center c_k = M_k @ mu_k.
+// This is what score_samples / predict use; predict_proba still needs the full
+// (N, K) responsibilities so it keeps the standard infer() path.
+// `labels` may be nullptr (score_samples_impl passes nullptr); the GEMM path
+// then skips its best_lp scratch and hands null pointers to the kernels.
+// `log_prob_norm` is always passed through as the output (n entries). The
+// stream is not synchronized here.
+template
+void fused_score(raft::resources const& handle,
+                 const params& params,
+                 const T* X,
+                 int n,
+                 int d,
+                 const T* weights,
+                 const T* means,
+                 const T* precisions_chol,
+                 T* log_prob_norm,
+                 int* labels)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  int K = params.n_components;
+  covariance_type ct = params.cov_type;
+
+  // log_det is not stored with the model; rebuild it from precisions_chol.
+  rmm::device_uvector log_det(K, stream);
+  if (ct == covariance_type::FULL)
+    detail::log_det_full_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else if (ct == covariance_type::TIED)
+    detail::log_det_tied_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else if (ct == covariance_type::DIAG)
+    detail::log_det_diag_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  else
+    detail::log_det_spherical_kernel
+      <<>>(precisions_chol, d, K, log_det.data());
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+
+  // Precompute the per-component constant once (avoids log() in the hot loop).
+  rmm::device_uvector const_k(K, stream);
+  detail::fused_const_kernel<<>>(
+    weights, log_det.data(), d, K, const_k.data());
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+
+  int threads = 256;
+  int rb = (n + threads - 1) / threads;  // ceil(n / threads)
+
+  if (ct == covariance_type::SPHERICAL) {
+    launch_estep_tiled(
+      handle, X, means, precisions_chol, const_k.data(), n, d, K, log_prob_norm, labels);
+  } else if (ct == covariance_type::DIAG) {
+    launch_estep_tiled(
+      handle, X, means, precisions_chol, const_k.data(), n, d, K, log_prob_norm, labels);
+  } else if (ct == covariance_type::TIED && d <= 128) {
+    // Shared precision Cholesky U: transform X̃ = X·U and μ̃ = means·U once,
+    // then the Mahalanobis distance ‖Uᵀ(x-μ_k)‖² becomes the Euclidean
+    // ‖x̃ - μ̃_k‖² -> reuse the fast register-tiled kernel (prec = 1).
+    cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+    cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+    rmm::device_uvector xt((size_t)n * d, stream);   // transformed samples (n*d extra memory)
+    rmm::device_uvector mut((size_t)K * d, stream);  // transformed means
+    rmm::device_uvector ones_pc(K, stream);          // unit per-component precision
+    thrust::fill(thrust::cuda::par.on(stream), ones_pc.data(), ones_pc.data() + K, T(1));
+    T one = T(1), zero = T(0);
+    // X̃ = X @ U and μ̃ = means @ U (same GEMM convention as the E-step).
+    cublas_check(cublas_gemm(cublas,
+                             CUBLAS_OP_N,
+                             CUBLAS_OP_N,
+                             d,
+                             n,
+                             d,
+                             &one,
+                             precisions_chol,
+                             d,
+                             X,
+                             d,
+                             &zero,
+                             xt.data(),
+                             d),
+                 "gemm(tied_transform_X)");
+    cublas_check(cublas_gemm(cublas,
+                             CUBLAS_OP_N,
+                             CUBLAS_OP_N,
+                             d,
+                             K,
+                             d,
+                             &one,
+                             precisions_chol,
+                             d,
+                             means,
+                             d,
+                             &zero,
+                             mut.data(),
+                             d),
+                 "gemm(tied_transform_mu)");
+    launch_estep_tiled(handle,
+                       xt.data(),
+                       mut.data(),
+                       ones_pc.data(),
+                       const_k.data(),
+                       n,
+                       d,
+                       K,
+                       log_prob_norm,
+                       labels);
+  } else {
+    // full (or tied with d > 128): un-centered E-step. y = M_k @ X is formed by
+    // one GEMM per component, then folded online as ||y - c_k||^2 with the
+    // projected center c_k = M_k @ mu_k. On this GDDR7 card the path is
+    // memory-bound; dropping the explicit centering (a full N×d write + read
+    // per component) cuts the dominant HBM traffic ~1.67x.
+    int prec_pc = (ct == covariance_type::FULL) ? 1 : 0;  // 1: one factor per component; 0: shared
+    cublasHandle_t cublas = raft::resource::get_cublas_handle(handle);
+    cublas_check(cublasSetStream(cublas, stream), "cublasSetStream");
+    rmm::device_uvector y((size_t)n * d, stream);  // projected samples, rewritten for each k
+    rmm::device_uvector c((size_t)K * d, stream);  // projected centers for all components
+    rmm::device_uvector rmax(n, stream);           // per-sample running max (online log-sum-exp)
+    rmm::device_uvector rsum(n, stream);           // per-sample running sum (online log-sum-exp)
+    rmm::device_uvector best_lp(labels ? n : 0, stream);  // only needed when labels are requested
+    T one = T(1), zero = T(0);
+    detail::fused_center_proj_kernel<<<(K * d + threads - 1) / threads, threads, 0, stream>>>(
+      precisions_chol, means, d, K, prec_pc, c.data());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    detail::fused_lse_init_kernel<<>>(
+      rmax.data(), rsum.data(), labels ? best_lp.data() : nullptr, labels, n);
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    for (int k = 0; k < K; ++k) {
+      // TIED reuses factor 0 for every component.
+      const T* pc_k = precisions_chol + (size_t)(prec_pc ? k : 0) * d * d;
+      cublas_check(
+        cublas_gemm(
+          cublas, CUBLAS_OP_N, CUBLAS_OP_N, d, n, d, &one, pc_k, d, X, d, &zero, y.data(), d),
+        "gemm(fused_score)");
+      detail::fused_fold_from_y_kernel
+        <<>>(y.data(),
+             const_k.data(),
+             c.data(),
+             n,
+             d,
+             k,
+             rmax.data(),
+             rsum.data(),
+             labels ? best_lp.data() : nullptr,
+             labels);
+      RAFT_CUDA_TRY(cudaPeekAtLastError());
+    }
+    detail::fused_lse_finalize_kernel
+      <<>>(rmax.data(), rsum.data(), n, log_prob_norm);
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+  }
+}
+
+// Hard labels for n samples via fused_score; blocks until the stream is idle.
+template
+void predict_impl(raft::resources const& handle,
+                  const params& params,
+                  const T* X,
+                  int n,
+                  int d,
+                  const T* weights,
+                  const T* means,
+                  const T* precisions_chol,
+                  int* labels)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  rmm::device_uvector lpn(n, stream);  // fused_score needs somewhere to write; never read back
+  fused_score(handle, params, X, n, d, weights, means, precisions_chol, lpn.data(), labels);
+  raft::resource::sync_stream(handle);
+}
+
+// Responsibilities for n samples via the standard infer() path; blocks until
+// the stream is idle.
+template
+void predict_proba_impl(raft::resources const& handle,
+                        const params& params,
+                        const T* X,
+                        int n,
+                        int d,
+                        const T* weights,
+                        const T* means,
+                        const T* precisions_chol,
+                        T* resp)
+{
+  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
+  rmm::device_uvector lpn(n, stream);  // infer() also emits log_prob_norm; discarded here
+  infer(handle, params, X, n, d, weights, means, precisions_chol, resp, lpn.data());
+  raft::resource::sync_stream(handle);
+}
+
+// Per-sample log_prob_norm via fused_score with no labels requested; blocks
+// until the stream is idle.
+template
+void score_samples_impl(raft::resources const& handle,
+                        const params& params,
+                        const T* X,
+                        int n,
+                        int d,
+                        const T* weights,
+                        const T* means,
+                        const T* precisions_chol,
+                        T* log_prob_norm)
+{
+  fused_score(handle, params, X, n, d, weights, means, precisions_chol, log_prob_norm, nullptr);
+  raft::resource::sync_stream(handle);
+}
+
+}  // namespace cuvs::cluster::gmm::detail
diff --git a/cpp/src/cluster/gmm_kernels.cuh b/cpp/src/cluster/gmm_kernels.cuh
new file mode 100644
index 0000000000..5c9cf93ecd
--- /dev/null
+++ b/cpp/src/cluster/gmm_kernels.cuh
@@ -0,0 +1,870 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/*
+ * CUDA kernels for the Gaussian mixture E- and M-steps, covering all four
+ * covariance parameterizations (full, tied, diag, spherical).
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace cuvs::cluster::gmm::detail {
+
+constexpr float LOG_2PI_F = 1.8378770664093453f;  // ln(2*pi), single precision
+constexpr double LOG_2PI_D = 1.8378770664093453;  // ln(2*pi), double precision
+
+template
+__device__ __forceinline__ T log_2pi_const();  // declared only; defined by the two specializations below
+template <>
+__device__ __forceinline__ float log_2pi_const()
+{
+  return LOG_2PI_F;
+}
+template <>
+__device__ __forceinline__ double log_2pi_const()
+{
+  return LOG_2PI_D;
+}
+
+__device__ __forceinline__ int upper_tri_col_offset(int col) { return (col * (col + 1)) / 2; }  // 0 + 1 + ... + col
+
+// ===========================================================================
+// E-step: log Gaussian probability
+// ===========================================================================
+
+// ---------------------------------------------------------------------------
+// full / tied, d <= 64. ``prec_per_component`` is 1 for full (prec_chol has a
+// d×d matrix per component) and 0 for tied (a single shared d×d matrix).
+// --------------------------------------------------------------------------- +template +__global__ void e_step_log_prob_small_kernel(const T* __restrict__ X, + const T* __restrict__ weights, + const T* __restrict__ means, + const T* __restrict__ prec_chol, + const T* __restrict__ log_det, + int n, + int d, + int K, + int prec_per_component, + T* __restrict__ log_prob) +{ + static_assert(D >= 0 && D <= 64, "GMM small E-step supports runtime d or fixed D <= 64"); + constexpr bool fixed_d = D != 0; + int dim = fixed_d ? D : d; + int k = blockIdx.y; + int n_idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x; + int kpc = prec_per_component ? k : 0; + + extern __shared__ unsigned char smem_raw[]; + T* sh_mean = reinterpret_cast(smem_raw); + T* sh_pc = sh_mean + dim; + + for (int i = tid; i < dim; i += blockDim.x) + sh_mean[i] = means[(size_t)k * dim + i]; + int pc_size_dense = dim * dim; + for (int i = tid; i < pc_size_dense; i += blockDim.x) { + int row = i / dim; + int col = i - row * dim; + if (row <= col) { + sh_pc[upper_tri_col_offset(col) + row] = prec_chol[(size_t)kpc * pc_size_dense + i]; + } + } + + __shared__ T sh_const; + if (tid == 0) { sh_const = T(-0.5) * T(dim) * log_2pi_const() + log_det[k] + log(weights[k]); } + + __syncthreads(); + + if (n_idx >= n) return; + + T centered_vals[fixed_d ? 
D : 64]; + if constexpr (fixed_d) { +#pragma unroll + for (int dd = 0; dd < D; ++dd) + centered_vals[dd] = X[(size_t)n_idx * D + dd] - sh_mean[dd]; + } else { + for (int dd = 0; dd < dim; ++dd) + centered_vals[dd] = X[(size_t)n_idx * dim + dd] - sh_mean[dd]; + } + + T mahal = T(0); + if constexpr (fixed_d) { +#pragma unroll + for (int j = 0; j < D; ++j) { + T y = T(0); + int pc_col = upper_tri_col_offset(j); +#pragma unroll + for (int dd = 0; dd <= j; ++dd) { + y += centered_vals[dd] * sh_pc[pc_col + dd]; + } + mahal += y * y; + } + } else { + for (int j = 0; j < dim; ++j) { + T y = T(0); + int pc_col = upper_tri_col_offset(j); + for (int dd = 0; dd <= j; ++dd) { + y += centered_vals[dd] * sh_pc[pc_col + dd]; + } + mahal += y * y; + } + } + log_prob[(size_t)n_idx * K + k] = sh_const - T(0.5) * mahal; +} + +// --------------------------------------------------------------------------- +// full / tied, d > 64, tiled over 64-column blocks. +// --------------------------------------------------------------------------- +template +__global__ void e_step_log_prob_large_d_thread64_kernel(const T* __restrict__ X, + const T* __restrict__ weights, + const T* __restrict__ means, + const T* __restrict__ prec_chol, + const T* __restrict__ log_det, + int n, + int d, + int K, + int prec_per_component, + T* __restrict__ log_prob) +{ + static_assert(TILE_D == 64, "GMM thread64 E-step expects a 64-column precision tile"); + + int k = blockIdx.y; + int row = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x; + int kpc = prec_per_component ? 
k : 0; + + extern __shared__ unsigned char smem_raw[]; + T* sh_mean = reinterpret_cast(smem_raw); // (64,) + T* sh_pc = sh_mean + TILE_D; // (64, 64) + + __shared__ T sh_const; + if (tid == 0) { sh_const = T(-0.5) * T(d) * log_2pi_const() + log_det[k] + log(weights[k]); } + + T local_mahal = T(0); + const T* pc = prec_chol + (size_t)kpc * d * d; + + for (int j_base = 0; j_base < d; j_base += TILE_D) { + int cols_in_tile = min(TILE_D, d - j_base); + int dd_limit = min(d, j_base + TILE_D); + T y[TILE_D]; +#pragma unroll + for (int col = 0; col < TILE_D; ++col) + y[col] = T(0); + + for (int dd_base = 0; dd_base < dd_limit; dd_base += TILE_D) { + int feats_in_tile = min(TILE_D, dd_limit - dd_base); + + for (int idx = tid; idx < TILE_D; idx += blockDim.x) { + sh_mean[idx] = (idx < feats_in_tile) ? means[(size_t)k * d + dd_base + idx] : T(0); + } + + constexpr int pc_tile_elems = TILE_D * TILE_D; + for (int idx = tid; idx < pc_tile_elems; idx += blockDim.x) { + int feat = idx / TILE_D; + int col_local = idx - feat * TILE_D; + int dd = dd_base + feat; + int col = j_base + col_local; + T val = T(0); + if (feat < feats_in_tile && col_local < cols_in_tile && dd <= col) { + val = pc[(size_t)dd * d + col]; + } + sh_pc[feat * TILE_D + col_local] = val; + } + + __syncthreads(); + + if (row < n) { +#pragma unroll + for (int feat = 0; feat < TILE_D; ++feat) { + if (feat >= feats_in_tile) break; + T diff = X[(size_t)row * d + dd_base + feat] - sh_mean[feat]; +#pragma unroll + for (int col = 0; col < TILE_D; ++col) { + if (col >= cols_in_tile) break; + y[col] += diff * sh_pc[feat * TILE_D + col]; + } + } + } + + __syncthreads(); + } + + if (row < n) { +#pragma unroll + for (int col = 0; col < TILE_D; ++col) { + if (col >= cols_in_tile) break; + local_mahal += y[col] * y[col]; + } + } + } + + if (row < n) log_prob[(size_t)row * K + k] = sh_const - T(0.5) * local_mahal; +} + +// --------------------------------------------------------------------------- +// cuBLAS E-step (full/tied, 
wide d): center one component's data, X - means[k].
+// ---------------------------------------------------------------------------
+// One thread per element of X, addressed as a flat index idx = row * d + col.
+//   X        : input samples, read at the same flat index as the output.
+//   means    : component means; row k (d values) is subtracted from every sample.
+//   n, d     : number of samples / features (n * d elements are processed).
+//   k        : component whose mean is subtracted.
+//   centered : output, same layout as X.
+template <typename T>
+__global__ void e_step_center_kernel(const T* __restrict__ X,
+                                     const T* __restrict__ means,
+                                     int n,
+                                     int d,
+                                     int k,
+                                     T* __restrict__ centered)
+{
+  // 64-bit flat index: n * d may exceed the int32 range.
+  size_t idx   = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
+  size_t total = (size_t)n * d;
+  if (idx >= total) return;
+  int col       = idx % d;
+  centered[idx] = X[idx] - means[(size_t)k * d + col];
+}
+
+// Projected centers c[k] = M_k @ mu_k for the un-centered E-step, where M_k is
+// the (column-major, lda=d) prec-chol matrix used by the cuBLAS GEMM. Lets the
+// fused fold compute ||M_k x - c_k||^2 without first materializing (X - mu_k).
+//
+// One thread per output element c[k, i] (K * d threads in total).
+//   prec_chol : d*d values per component when prec_pc != 0, otherwise a single
+//               d*d matrix shared by every component.
+//   means     : component means, d values per component.
+//   c         : output, d values per component.
+template <typename T>
+__global__ void fused_center_proj_kernel(const T* __restrict__ prec_chol,
+                                         const T* __restrict__ means,
+                                         int d,
+                                         int K,
+                                         int prec_pc,
+                                         T* __restrict__ c)
+{
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= K * d) return;
+  int k = idx / d;
+  int i = idx % d;
+  // Shared matrix (prec_pc == 0) -> always use the first d*d slab.
+  const T* M  = prec_chol + (size_t)(prec_pc ? k : 0) * d * d;
+  const T* mu = means + (size_t)k * d;
+  T acc       = T(0);
+  // Element (i, j) of the column-major matrix lives at i + j * d.
+  for (int j = 0; j < d; ++j)
+    acc += M[(size_t)i + (size_t)j * d] * mu[j];
+  c[(size_t)k * d + i] = acc;
+}
+
+// cuBLAS E-step: from y = (X - means[k]) @ prec_chol[k] (n, d), compute
+// log_prob[:, k] = const_k - 0.5 * sum_j y[:, j]^2 (Kahan-compensated).
+// One thread per sample row.
+//   y        : (n, d) row-major projected, centered data for component k.
+//   weights  : mixture weights; only weights[k] is read.
+//   log_det  : per-component log-determinant term; only log_det[k] is read.
+//   log_prob : (n, K) row-major output; only column k is written.
+template <typename T>
+__global__ void e_step_log_prob_from_y_kernel(const T* __restrict__ y,
+                                              const T* __restrict__ weights,
+                                              const T* __restrict__ log_det,
+                                              int n,
+                                              int d,
+                                              int K,
+                                              int k,
+                                              T* __restrict__ log_prob)
+{
+  int row = blockIdx.x * blockDim.x + threadIdx.x;
+  if (row >= n) return;
+
+  // Kahan-compensated sum of squares: `compensation` carries the rounding
+  // error of the previous addition into the next term.
+  T mahal        = T(0);
+  T compensation = T(0);
+  for (int col = 0; col < d; ++col) {
+    T v          = y[(size_t)row * d + col];
+    T term       = v * v - compensation;
+    T next       = mahal + term;
+    compensation = (next - mahal) - term;
+    mahal        = next;
+  }
+  T constant = T(-0.5) * T(d) * log_2pi_const<T>() + log_det[k] + log(weights[k]);
+  log_prob[(size_t)row * K + k] = constant - T(0.5) * mahal;
+}
+
+// ---------------------------------------------------------------------------
+// Per-row log-sum-exp normalize: resp = exp(log_prob - logsumexp); also writes
+// the per-sample log-likelihood (= logsumexp) into log_prob_norm.
+// ---------------------------------------------------------------------------
+// One block per sample (n_idx = blockIdx.x); threads stride over the K columns.
+//
+// NOTE(review): both reductions are warp shuffles that start at offset 16 with
+// a full 0xffffffff mask, and only thread 0 publishes the result. Nothing
+// combines partial results across warps, so this is only correct when the
+// block is exactly one warp (blockDim.x == 32). The launch configuration is
+// not visible here -- confirm at the call site.
+template <typename T>
+__global__ void e_step_normalize_kernel(
+  const T* __restrict__ log_prob, int n, int K, T* __restrict__ resp, T* __restrict__ log_prob_norm)
+{
+  int n_idx = blockIdx.x;
+  if (n_idx >= n) return;
+  int tid = threadIdx.x;
+  // 64-bit row offset: n_idx * K overflows int32 once n * K exceeds 2^31.
+  size_t row = (size_t)n_idx * K;
+
+  __shared__ T sh_max;
+  __shared__ T sh_sum;
+
+  // Pass 1: row maximum (subtracted before exp() for numerical stability).
+  T local_max = -CUDART_INF_F;
+  for (int k = tid; k < K; k += blockDim.x) {
+    T v = log_prob[row + k];
+    if (v > local_max) local_max = v;
+  }
+  for (int off = 16; off > 0; off >>= 1) {
+    T other = __shfl_down_sync(0xffffffff, local_max, off);
+    if (other > local_max) local_max = other;
+  }
+  if (tid == 0) sh_max = local_max;
+  __syncthreads();
+  T mx = sh_max;
+
+  // Pass 2: sum of exp(log_prob - max).
+  T local_sum = T(0);
+  for (int k = tid; k < K; k += blockDim.x) {
+    local_sum += exp(log_prob[row + k] - mx);
+  }
+  for (int off = 16; off > 0; off >>= 1)
+    local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
+  if (tid == 0) {
+    sh_sum               = local_sum;
+    log_prob_norm[n_idx] = log(local_sum) + mx;
+  }
+  __syncthreads();
+  T log_total = log(sh_sum) + mx;
+
+  // Pass 3: responsibilities.
+  for (int k = tid; k < K; k += blockDim.x) {
+    resp[row + k] = exp(log_prob[row + k] - log_total);
+  }
+}
+
+// ===========================================================================
+// M-step
+// ===========================================================================
+
+// means[k] = num[k] / N_k[k]; weights[k] = N_k[k] / n.
+// Both use the guarded count N_k[k] + 10 * eps, so an empty component does not
+// divide by zero. One block per component; threads stride over the d features.
+//   N_k     : per-component counts (K values).
+//   num     : (K, d) row-major numerator of the means.
+//   weights : output, K values.
+//   means   : output, (K, d) row-major.
+template <typename T>
+__global__ void m_step_finalize_means_kernel(const T* __restrict__ N_k,
+                                             const T* __restrict__ num,
+                                             T* __restrict__ weights,
+                                             T* __restrict__ means,
+                                             T eps,
+                                             int n,
+                                             int d,
+                                             int K)
+{
+  int k   = blockIdx.x;
+  int tid = threadIdx.x;
+  if (k >= K) return;
+
+  T Nk     = N_k[k] + T(10) * eps;
+  T inv_Nk = T(1) / Nk;
+  if (tid == 0) weights[k] = Nk / T(n);
+
+  // 64-bit offset, matching the (size_t)k * d indexing of the other M-step kernels.
+  size_t base = (size_t)k * d;
+  for (int i = tid; i < d; i += blockDim.x)
+    means[base + i] = num[base + i] * inv_Nk;
+}
+
+// sqrt(resp) * (X - means[k]) for one component k. Output feeds a GEMM forming
+// the (unnormalized) full covariance.
+// One thread per element of X, addressed as a flat index idx = row * d + col.
+//   X        : samples, read at the flat index.
+//   resp     : (n, K) row-major responsibilities; column k is read.
+//   means    : component means (d values per component), or nullptr for the
+//              uncentered moment.
+//   centered : output, same layout as X.
+template <typename T>
+__global__ void weighted_center_kernel(const T* __restrict__ X,
+                                       const T* __restrict__ resp,
+                                       const T* __restrict__ means,
+                                       int n,
+                                       int d,
+                                       int K,
+                                       int k,
+                                       T* __restrict__ centered)
+{
+  size_t idx   = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
+  size_t total = (size_t)n * d;
+  if (idx >= total) return;
+
+  int row = idx / d;
+  int col = idx - (size_t)row * d;
+  // 64-bit offset: row * K overflows int32 once n * K exceeds 2^31.
+  T r = resp[(size_t)row * K + k];
+  // means == nullptr -> sqrt(r)*x (uncentered moment); else sqrt(r)*(x - mu_k).
+  T mu          = means ? means[(size_t)k * d + col] : T(0);
+  centered[idx] = sqrt(r) * (X[idx] - mu);
+}
+
+// Finalize one full covariance from the centered moment Σ r·(x-mu)(x-mu)ᵀ:
+// divide by N_k, symmetrize the row-major result the column-major GEMM produced,
+// and add reg_covar to the diagonal.
+//
+// One block per component (k = blockIdx.x); threads stride over the d*d
+// entries and only the i <= j half does work. Each (i, j) pair is handled by a
+// single thread, which reads entry (j, i) and writes both (i, j) and (j, i).
+//   N_k         : per-component counts; guarded with + 10 * eps before dividing.
+//   covariances : (K, d, d), updated in place.
+template <typename T>
+__global__ void m_step_finalize_cov_full_kernel(
+  const T* __restrict__ N_k, T* __restrict__ covariances, T reg_covar, T eps, int d, int K)
+{
+  int k   = blockIdx.x;
+  int tid = threadIdx.x;
+  if (k >= K) return;
+
+  T Nk      = N_k[k] + T(10) * eps;
+  T inv_Nk  = T(1) / Nk;
+  int total = d * d;
+  T* cov    = covariances + (size_t)k * d * d;
+
+  for (int idx = tid; idx < total; idx += blockDim.x) {
+    int i = idx / d;
+    int j = idx % d;
+    if (i > j) continue;
+
+    T v = cov[j * d + i] * inv_Nk;
+    if (i == j) v += reg_covar;
+    cov[i * d + j] = v;
+    if (i != j) cov[j * d + i] = v;
+  }
+}
+
+// Square every element of X into Xsq (used by the diagonal M-step GEMM).
+// One thread per element; `total` is the element count of both buffers.
+template <typename T>
+__global__ void elementwise_square_kernel(const T* __restrict__ X,
+                                          size_t total,
+                                          T* __restrict__ out)
+{
+  size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= total) return;
+  T v      = X[idx];
+  out[idx] = v * v;
+}
+
+// diag covariance: var[k,dd] = num2[k,dd]/N_k[k] - means[k,dd]^2 + reg_covar,
+// where num2 = respᵀ @ (X*X).
+// One block per component (k = blockIdx.x); threads stride over the d features.
+//   N_k       : per-component counts; guarded with + 10 * eps before dividing.
+//   num2      : (K, d) row-major second-moment numerator.
+//   means     : (K, d) row-major component means.
+//   variances : (K, d) row-major output.
+template <typename T>
+__global__ void m_step_finalize_diag_kernel(const T* __restrict__ N_k,
+                                            const T* __restrict__ num2,
+                                            const T* __restrict__ means,
+                                            T reg_covar,
+                                            T eps,
+                                            int d,
+                                            int K,
+                                            T* __restrict__ variances)
+{
+  int k   = blockIdx.x;
+  int tid = threadIdx.x;
+  if (k >= K) return;
+  T Nk     = N_k[k] + T(10) * eps;
+  T inv_Nk = T(1) / Nk;
+  for (int dd = tid; dd < d; dd += blockDim.x) {
+    T mu                          = means[(size_t)k * d + dd];
+    variances[(size_t)k * d + dd] = num2[(size_t)k * d + dd] * inv_Nk - mu * mu + reg_covar;
+  }
+}
+
+// spherical covariance: var[k] = mean over features of the diagonal variances.
+//
+// One block per component. Each thread accumulates a strided partial sum, then
+// the block combines them with a halving reduction in shared memory.
+// Launch precondition implied by the code: blockDim.x must be a power of two
+// (the halving loop drops elements otherwise) and at most 256 (size of `sh`).
+//   diag_var  : (K, d) row-major diagonal variances.
+//   variances : output, K values.
+template <typename T>
+__global__ void m_step_spherical_from_diag_kernel(const T* __restrict__ diag_var,
+                                                  int d,
+                                                  int K,
+                                                  T* __restrict__ variances)
+{
+  int k   = blockIdx.x;
+  int tid = threadIdx.x;
+  if (k >= K) return;
+
+  __shared__ T sh[256];
+  T local = T(0);
+  for (int dd = tid; dd < d; dd += blockDim.x)
+    local += diag_var[(size_t)k * d + dd];
+  sh[tid] = local;
+  __syncthreads();
+  for (int off = blockDim.x / 2; off > 0; off >>= 1) {
+    if (tid < off) sh[tid] += sh[tid + off];
+    __syncthreads();
+  }
+  if (tid == 0) variances[k] = sh[0] / T(d);
+}
+
+// Scale each component mean by N_k (for the tied covariance weighted outer
+// product (N_k * means)ᵀ @ means).
+//
+// One block per component; threads stride over the d features. The factor is
+// scale[k] + 10 * eps, the same guarded count the other M-step kernels use.
+//   in / out : (K, d) row-major.
+template <typename T>
+__global__ void scale_rows_by_kernel(
+  const T* __restrict__ in, const T* __restrict__ scale, T eps, int d, int K, T* __restrict__ out)
+{
+  int k   = blockIdx.x;
+  int tid = threadIdx.x;
+  if (k >= K) return;
+  T s = scale[k] + T(10) * eps;
+  for (int dd = tid; dd < d; dd += blockDim.x)
+    out[(size_t)k * d + dd] = in[(size_t)k * d + dd] * s;
+}
+
+// tied covariance: cov = (XtX - weighted_means_outer) / sum(N_k) + reg_covar*I.
+// One thread per (i, j) entry of the d*d output. Only threads with i <= j do
+// work: each reads entry (i, j) of both inputs and writes (i, j) and (j, i).
+//   XtX            : d*d second-moment matrix.
+//   weighted_outer : d*d weighted outer product of the means.
+//   sum_Nk         : total count used as the divisor (passed by value).
+//   covariance     : d*d output.
+template <typename T>
+__global__ void m_step_finalize_tied_kernel(const T* __restrict__ XtX,
+                                            const T* __restrict__ weighted_outer,
+                                            T sum_Nk,
+                                            T reg_covar,
+                                            int d,
+                                            T* __restrict__ covariance)
+{
+  int idx   = blockIdx.x * blockDim.x + threadIdx.x;
+  int total = d * d;
+  if (idx >= total) return;
+  int i = idx / d;
+  int j = idx % d;
+  if (i > j) return;  // upper triangle drives a symmetric write
+  T inv = T(1) / sum_Nk;
+  T v   = (XtX[idx] - weighted_outer[idx]) * inv;
+  if (i == j) v += reg_covar;
+  covariance[(size_t)i * d + j] = v;
+  if (i != j) covariance[(size_t)j * d + i] = v;
+}
+
+// ===========================================================================
+// Precision Cholesky helpers (diag / spherical) and log-determinants
+// ===========================================================================
+
+// diag/spherical precision Cholesky: 1 / sqrt(variance), element-wise.
+// One thread per element; `total` is the element count of both buffers.
+template <typename T>
+__global__ void recip_sqrt_kernel(const T* __restrict__ var, size_t total, T* __restrict__ out)
+{
+  size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= total) return;
+  out[idx] = T(1) / sqrt(var[idx]);
+}
+
+// log-det of the precision Cholesky for full covariance:
+// log_det[k] = sum_i log(prec_chol[k,i,i]).
+//
+// One block per component. Each thread accumulates a strided partial sum over
+// the diagonal, then the block combines them with a halving reduction.
+// Launch precondition implied by the code: blockDim.x must be a power of two
+// (the halving loop drops elements otherwise) and at most 256 (size of `sh`).
+template <typename T>
+__global__ void log_det_full_kernel(const T* __restrict__ prec_chol,
+                                    int d,
+                                    int K,
+                                    T* __restrict__ log_det)
+{
+  int k   = blockIdx.x;
+  int tid = threadIdx.x;
+  if (k >= K) return;
+  __shared__ T sh[256];
+  T local       = T(0);
+  const T* pc_k = prec_chol + (size_t)k * d * d;
+  for (int i = tid; i < d; i += blockDim.x)
+    local += log(pc_k[(size_t)i * d + i]);
+  sh[tid] = local;
+  __syncthreads();
+  for (int off = blockDim.x / 2; off > 0; off >>= 1) {
+    if (tid < off) sh[tid] += sh[tid + off];
+    __syncthreads();
+  }
+  if (tid == 0) log_det[k] = sh[0];
+}
+
+// log-det for tied covariance: a single value broadcast to all K components.
+// The block sums log of the diagonal of the single d*d prec_chol matrix, then
+// thread 0 writes that sum into all K entries of log_det. The kernel does not
+// read blockIdx, so every launched block repeats the same work.
+// Launch precondition implied by the code: blockDim.x must be a power of two
+// (the halving loop drops elements otherwise) and at most 256 (size of `sh`).
+template <typename T>
+__global__ void log_det_tied_kernel(const T* __restrict__ prec_chol,
+                                    int d,
+                                    int K,
+                                    T* __restrict__ log_det)
+{
+  int tid = threadIdx.x;
+  __shared__ T sh[256];
+  T local = T(0);
+  for (int i = tid; i < d; i += blockDim.x)
+    local += log(prec_chol[(size_t)i * d + i]);
+  sh[tid] = local;
+  __syncthreads();
+  for (int off = blockDim.x / 2; off > 0; off >>= 1) {
+    if (tid < off) sh[tid] += sh[tid + off];
+    __syncthreads();
+  }
+  if (tid == 0) {
+    for (int k = 0; k < K; ++k)
+      log_det[k] = sh[0];
+  }
+}
+
+// log-det for diagonal covariance: log_det[k] = sum_dd log(prec_chol[k,dd]).
+// One block per component; same shared-memory reduction (and the same
+// power-of-two, <= 256 blockDim.x precondition) as the kernel above.
+template <typename T>
+__global__ void log_det_diag_kernel(const T* __restrict__ prec_chol,
+                                    int d,
+                                    int K,
+                                    T* __restrict__ log_det)
+{
+  int k   = blockIdx.x;
+  int tid = threadIdx.x;
+  if (k >= K) return;
+  __shared__ T sh[256];
+  T local = T(0);
+  for (int dd = tid; dd < d; dd += blockDim.x)
+    local += log(prec_chol[(size_t)k * d + dd]);
+  sh[tid] = local;
+  __syncthreads();
+  for (int off = blockDim.x / 2; off > 0; off >>= 1) {
+    if (tid < off) sh[tid] += sh[tid + off];
+    __syncthreads();
+  }
+  if (tid == 0) log_det[k] = sh[0];
+}
+
+// log-det for spherical covariance: log_det[k] = d * log(prec_chol[k]).
+// One thread per component.
+template <typename T>
+__global__ void log_det_spherical_kernel(const T* __restrict__ prec_chol,
+                                         int d,
+                                         int K,
+                                         T* __restrict__ log_det)
+{
+  int k = blockIdx.x * blockDim.x + threadIdx.x;
+  if (k >= K) return;
+  log_det[k] = T(d) * log(prec_chol[k]);
+}
+
+// ===========================================================================
+// Miscellaneous
+// ===========================================================================
+
+// Identity matrix (d×d), row-major. One thread per element.
+template <typename T>
+__global__ void set_identity_kernel(T* __restrict__ A, int d)
+{
+  int idx   = blockIdx.x * blockDim.x + threadIdx.x;
+  int total = d * d;
+  if (idx >= total) return;
+  int i  = idx / d;
+  int j  = idx % d;
+  A[idx] = (i == j) ? T(1) : T(0);
+}
+
+// Batched identity: K stacked d×d identity matrices, row-major (K, d, d).
+// One thread per element; the flat index is 64-bit because K * d * d may
+// exceed the int32 range.
+template <typename T>
+__global__ void set_identity_batched_kernel(T* __restrict__ A, int d, int K)
+{
+  size_t idx   = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
+  size_t total = (size_t)K * d * d;
+  if (idx >= total) return;
+  size_t within = idx % ((size_t)d * d);  // position inside this d×d matrix
+  int i         = (int)(within / d);
+  int j         = (int)(within % d);
+  A[idx]        = (i == j) ? T(1) : T(0);
+}
+
+// Argmax over components: labels[n] = argmax_k resp[n,k].
+// One thread per sample. The strict `>` comparison keeps the lowest index on ties.
+template <typename T>
+__global__ void argmax_kernel(const T* __restrict__ resp, int n, int K, int* __restrict__ labels)
+{
+  int n_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n_idx >= n) return;
+  const T* r = resp + (size_t)n_idx * K;
+  int best   = 0;
+  T best_v   = r[0];
+  for (int k = 1; k < K; ++k) {
+    if (r[k] > best_v) {
+      best_v = r[k];
+      best   = k;
+    }
+  }
+  labels[n_idx] = best;
+}
+
+// Scatter hard labels into one-hot responsibilities (n×K).
+// One thread per sample; every entry of the row is written, so the output
+// needs no prior zeroing. A label outside [0, K) leaves the row all zeros.
+template <typename T>
+__global__ void labels_to_onehot_kernel(const int* __restrict__ labels,
+                                        int n,
+                                        int K,
+                                        T* __restrict__ resp)
+{
+  int n_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n_idx >= n) return;
+  int lab = labels[n_idx];
+  for (int k = 0; k < K; ++k)
+    resp[(size_t)n_idx * K + k] = (k == lab) ? T(1) : T(0);
+}
+
+// Tile variant: scatter the one-hot for any chosen index falling in [t0, t0+nt)
+// into a tile-local (nt, K) buffer (memset to 0 by the caller).
+// One thread per component k; indices[k] is the sample chosen for component k.
+template <typename T>
+__global__ void scatter_onehot_tile_kernel(
+  const int* __restrict__ indices, int K, int t0, int nt, T* __restrict__ resp_tile)
+{
+  int k = blockIdx.x * blockDim.x + threadIdx.x;
+  if (k >= K) return;
+  int s = indices[k] - t0;  // row of the chosen sample relative to the tile start
+  if (s >= 0 && s < nt) resp_tile[(size_t)s * K + k] = T(1);
+}
+
+// Normalize each row of resp to sum to one (used by 'random' init).
+// One thread per sample row; resp is (n, K) row-major and is scaled in place.
+// A row whose entries sum to zero is not special-cased (division by zero).
+template <typename T>
+__global__ void normalize_rows_kernel(T* __restrict__ resp, int n, int K)
+{
+  int n_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n_idx >= n) return;
+  T* r  = resp + (size_t)n_idx * K;
+  T sum = T(0);
+  for (int k = 0; k < K; ++k)
+    sum += r[k];
+  T inv = T(1) / sum;
+  for (int k = 0; k < K; ++k)
+    r[k] *= inv;
+}
+
+// Fast tiled fused E-step (diag / spherical): thread-per-sample with the
+// sample's feature vector held in registers and component means (and, for
+// diag, reciprocal std-devs) streamed through shared memory in tiles of CELL
+// components. Online log-sum-exp -> log_prob_norm (+ optional argmax label),
+// no N×K materialization. ``MAXD`` is the register footprint for the feature
+// vector (caller picks the smallest bucket >= d); ``DIAG`` selects the
+// covariance form at compile time so the shared buffers are sized exactly.
+// NOTE(review): the paragraph above describes a MAXD-bucketed kernel, but the
+// kernel that follows it is fused_const_kernel, which has no such parameter.
+// It looks like a leftover from an earlier revision -- confirm and move/remove.
+//
+// const_k[k] = -0.5 d log(2pi) + log_det[k] + log(weights[k]), precomputed once
+// per component so the hot E-step loop never calls log() per (sample, k).
+// One thread per component.
+template <typename T>
+__global__ void fused_const_kernel(const T* __restrict__ weights,
+                                   const T* __restrict__ log_det,
+                                   int d,
+                                   int K,
+                                   T* __restrict__ const_k)
+{
+  int k = blockIdx.x * blockDim.x + threadIdx.x;
+  if (k >= K) return;
+  const_k[k] = T(-0.5) * T(d) * log_2pi_const<T>() + log_det[k] + log(weights[k]);
+}
+
+// Initialize the running (max, sum) accumulators for the cuBLAS fold path.
+// One thread per sample: rmax = -inf, rsum = 0. best_lp and labels are
+// optional (may be nullptr) and are reset to -inf / 0 when supplied.
+template <typename T>
+__global__ void fused_lse_init_kernel(T* __restrict__ rmax,
+                                      T* __restrict__ rsum,
+                                      T* __restrict__ best_lp,
+                                      int* __restrict__ labels,
+                                      int n)
+{
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i >= n) return;
+  rmax[i] = -CUDART_INF_F;
+  rsum[i] = T(0);
+  if (best_lp) best_lp[i] = -CUDART_INF_F;
+  if (labels) labels[i] = 0;
+}
+
+// Fold component k (full/tied) into the running log-sum-exp from the
+// un-centered projection y = x @ prec_chol_k: the Mahalanobis term is
+// ‖y - c_k‖² with the projected center c_k = prec_chol_k-projected mu_k,
+// summed with Kahan compensation (matches the standard cuBLAS E-step).
+// Optionally tracks the argmax component.
+// One thread per sample row.
+//   y        : (n, d) row-major un-centered projection for component k.
+//   const_k  : per-component additive constant; only const_k[k] is read.
+//   c        : projected centers, d values per component.
+//   rmax/rsum: running max and scaled sum of the online log-sum-exp, per row.
+//   best_lp  : running best log-probability per row (used only with labels).
+//   labels   : optional (may be nullptr); receives the argmax component.
+template <typename T>
+__global__ void fused_fold_from_y_kernel(const T* __restrict__ y,
+                                         const T* __restrict__ const_k,
+                                         const T* __restrict__ c,
+                                         int n,
+                                         int d,
+                                         int k,
+                                         T* __restrict__ rmax,
+                                         T* __restrict__ rsum,
+                                         T* __restrict__ best_lp,
+                                         int* __restrict__ labels)
+{
+  int row = blockIdx.x * blockDim.x + threadIdx.x;
+  if (row >= n) return;
+  // y = M_k @ x (un-centered). mahal = ||M_k(x - mu_k)||^2 = ||y - c_k||^2
+  // with the projected center c_k = M_k @ mu_k precomputed once per component.
+  const T* ck = c + (size_t)k * d;
+  T mahal = T(0), comp = T(0);
+  for (int col = 0; col < d; ++col) {
+    T v    = y[(size_t)row * d + col] - ck[col];
+    T term = v * v - comp;
+    T next = mahal + term;
+    comp   = (next - mahal) - term;
+    mahal  = next;
+  }
+  T lp = const_k[k] - T(0.5) * mahal;
+  if (labels && lp > best_lp[row]) {
+    best_lp[row] = lp;
+    labels[row]  = k;
+  }
+  // A component with zero density (lp == -inf, e.g. log(0) from a zero weight)
+  // contributes exp(-inf) = 0 to the sum. Skip it explicitly: while rmax is
+  // still -inf the update below would evaluate exp(-inf - (-inf)) = exp(NaN)
+  // and poison rsum for this row. NaN lp is deliberately not skipped.
+  if (lp == T(-CUDART_INF_F)) return;
+  T pm      = rmax[row];
+  T nm      = fmax(pm, lp);
+  rsum[row] = rsum[row] * exp(pm - nm) + exp(lp - nm);
+  rmax[row] = nm;
+}
+
+// log_prob_norm[n] = rmax[n] + log(rsum[n]).
+// One thread per sample.
+template <typename T>
+__global__ void fused_lse_finalize_kernel(const T* __restrict__ rmax,
+                                          const T* __restrict__ rsum,
+                                          int n,
+                                          T* __restrict__ log_prob_norm)
+{
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i >= n) return;
+  log_prob_norm[i] = rmax[i] + log(rsum[i]);
+}
+
+// KDE-style register-tiled E-step for spherical / diag covariances: one thread
+// per sample, components tiled in CELL register accumulators, features streamed
+// through shared memory in FEAT-wide tiles. The (N, K) distance matrix is never
+// materialized — distances fold straight into an online log-sum-exp. Holding
+// only CELL accumulators (not all of x) keeps register pressure low -> high
+// occupancy (mirrors cuvs::distance::kde's tiling, specialized per component).
+// DIAG=false: spherical (scalar prec_chol[k]); DIAG=true: diag (prec_chol[k,:]).
+// WRITE_FULL=false: online log-sum-exp -> log_prob_norm (+ argmax labels), no
+// N×K.
WRITE_FULL=true: also write the full per-component log_prob[n,k] (for the +// materialized fit / predict_proba path, replacing the slow per-mode kernels); +// log_prob_norm/labels are then left to the normalize kernel. +template +__global__ void estep_tiled_kernel(const T* __restrict__ X, + const T* __restrict__ means, + const T* __restrict__ prec_chol, + const T* __restrict__ const_k, + int n, + int d, + int K, + T* __restrict__ log_prob_norm, + int* __restrict__ labels, + T* __restrict__ log_prob) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + bool valid = i < n; + __shared__ T smu[FEAT * CELL]; + __shared__ T spc[DIAG ? FEAT * CELL : 1]; + T rmax = T(-1e30), rsum = T(0), best_lp = T(-1e30); + int best = 0; + for (int k0 = 0; k0 < K; k0 += CELL) { + int cells = min(CELL, K - k0); + T acc[CELL]; +#pragma unroll + for (int c = 0; c < CELL; ++c) + acc[c] = T(0); + for (int f0 = 0; f0 < d; f0 += FEAT) { + int feats = min(FEAT, d - f0); + for (int idx = threadIdx.x; idx < FEAT * CELL; idx += blockDim.x) { + int cell = idx / FEAT, feat = idx % FEAT; + bool in = (cell < cells && feat < feats); + size_t g = (size_t)(k0 + cell) * d + f0 + feat; + smu[feat * CELL + cell] = in ? means[g] : T(0); + if constexpr (DIAG) spc[feat * CELL + cell] = in ? 
prec_chol[g] : T(0); + } + __syncthreads(); + if (valid) { + for (int f = 0; f < feats; ++f) { + T xq = X[(size_t)i * d + f0 + f]; +#pragma unroll + for (int c = 0; c < CELL; ++c) { + T df = xq - smu[f * CELL + c]; + if constexpr (DIAG) { + T y = df * spc[f * CELL + c]; + acc[c] += y * y; + } else { + acc[c] += df * df; + } + } + } + } + __syncthreads(); + } + if (valid) { +#pragma unroll + for (int c = 0; c < CELL; ++c) { + if (c >= cells) break; + int k = k0 + c; + T mahal; + if constexpr (DIAG) { + mahal = acc[c]; + } else { + T pc = prec_chol[k]; + mahal = (pc * pc) * acc[c]; + } + T lp = const_k[k] - T(0.5) * mahal; + if constexpr (WRITE_FULL) { + log_prob[(size_t)i * K + k] = lp; + } else { + if (labels && lp > best_lp) { + best_lp = lp; + best = k; + } + T nm = fmax(rmax, lp); + rsum = rsum * exp(rmax - nm) + exp(lp - nm); + rmax = nm; + } + } + } + } + if (valid && !WRITE_FULL) { + log_prob_norm[i] = rmax + log(rsum); + if (labels) labels[i] = best; + } +} + +} // namespace cuvs::cluster::gmm::detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 13a07b10b5..50993462a6 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -120,7 +120,7 @@ ConfigureTest( ConfigureTest( NAME CLUSTER_TEST PATH cluster/kmeans.cu cluster/kmeans_balanced.cu cluster/kmeans_find_k.cu cluster/linkage.cu - cluster/connect_knn.cu cluster/spectral.cu + cluster/connect_knn.cu cluster/spectral.cu cluster/gmm.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/cluster/gmm.cu b/cpp/tests/cluster/gmm.cu new file mode 100644 index 0000000000..72c574584a --- /dev/null +++ b/cpp/tests/cluster/gmm.cu @@ -0,0 +1,462 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../test_utils.cuh" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include + +namespace cuvs::cluster::gmm { + +template +struct GMMInputs { + int n_row; + int n_col; + int n_components; + covariance_type cov_type; +}; + +// Number of elements in a covariance-typed buffer for the given type. +inline int64_t cov_len(covariance_type ct, int d, int K) +{ + switch (ct) { + case covariance_type::FULL: return (int64_t)K * d * d; + case covariance_type::TIED: return (int64_t)d * d; + case covariance_type::DIAG: return (int64_t)K * d; + case covariance_type::SPHERICAL: return (int64_t)K; + } + return 0; +} + +template +class GMMTest : public ::testing::TestWithParam> { + protected: + GMMTest() : stream(raft::resource::get_cuda_stream(handle)) {} + + void basicTest() + { + auto p = ::testing::TestWithParam>::GetParam(); + int n = p.n_row, d = p.n_col, K = p.n_components; + int64_t cn = cov_len(p.cov_type, d, K); + + // Well-separated blobs: hard labels should recover the generating clusters. 
+ auto d_X = raft::make_device_matrix(handle, n, d); + auto d_yref = raft::make_device_vector(handle, n); + raft::random::make_blobs(d_X.data_handle(), + d_yref.data_handle(), + n, + d, + K, + stream, + /* row_major */ true, + /* centers */ nullptr, + /* cluster_std */ nullptr, + /* cluster_std_scalar */ T(1.0), + /* shuffle */ false, + /* center_box_min */ static_cast(-10.0f), + /* center_box_max */ static_cast(10.0f), + /* seed */ 1234ULL); + + auto weights = raft::make_device_vector(handle, K); + auto means = raft::make_device_matrix(handle, K, d); + auto covs = raft::make_device_vector(handle, cn); + auto pchol = raft::make_device_vector(handle, cn); + auto precs = raft::make_device_vector(handle, cn); + auto labels = raft::make_device_vector(handle, n); + + params prm; + prm.n_components = K; + prm.cov_type = p.cov_type; + prm.init = init_method::KMeans; + prm.max_iter = 100; + prm.seed = 1234ULL; + + T lower_bound = 0; + int n_iter = 0; + bool converged = false; + + fit(handle, + prm, + raft::make_const_mdspan(d_X.view()), + weights.view(), + means.view(), + covs.view(), + pchol.view(), + precs.view(), + labels.view(), + raft::make_host_scalar_view(&lower_bound), + raft::make_host_scalar_view(&n_iter), + raft::make_host_scalar_view(&converged)); + + // Fit diagnostics are sane. Well-separated blobs converge well before + // max_iter, so the converged flag must be set and EM must have stopped early. + ASSERT_TRUE(std::isfinite((double)lower_bound)); + ASSERT_GE(n_iter, 1); + ASSERT_TRUE(converged) << "EM did not converge on well-separated blobs"; + ASSERT_LT(n_iter, prm.max_iter) << "converged flag set but EM ran to max_iter"; + + // Hard labels recover the ground-truth clusters on well-separated blobs. 
+ double ari_fit = + raft::stats::adjusted_rand_index(d_yref.data_handle(), labels.data_handle(), n, stream); + raft::resource::sync_stream(handle, stream); + ASSERT_GT(ari_fit, 0.95) << "fit labels disagree with ground truth"; + + // predict() on the same data reproduces the fit labels exactly. + auto labels2 = raft::make_device_vector(handle, n); + predict(handle, + prm, + raft::make_const_mdspan(d_X.view()), + raft::make_const_mdspan(weights.view()), + raft::make_const_mdspan(means.view()), + raft::make_const_mdspan(pchol.view()), + labels2.view()); + double ari_pred = + raft::stats::adjusted_rand_index(labels.data_handle(), labels2.data_handle(), n, stream); + raft::resource::sync_stream(handle, stream); + ASSERT_NEAR(ari_pred, 1.0, 1e-6) << "predict disagrees with fit labels"; + + // predict_proba rows form a valid distribution (sum to 1, non-negative). + auto resp = raft::make_device_matrix(handle, n, K); + predict_proba(handle, + prm, + raft::make_const_mdspan(d_X.view()), + raft::make_const_mdspan(weights.view()), + raft::make_const_mdspan(means.view()), + raft::make_const_mdspan(pchol.view()), + resp.view()); + std::vector h_resp((size_t)n * K); + raft::update_host(h_resp.data(), resp.data_handle(), (size_t)n * K, stream); + + // score_samples: per-sample log-likelihood; its mean equals the lower + // bound returned by fit (both are the average log p(x)). 
+ auto logp = raft::make_device_vector(handle, n); + score_samples(handle, + prm, + raft::make_const_mdspan(d_X.view()), + raft::make_const_mdspan(weights.view()), + raft::make_const_mdspan(means.view()), + raft::make_const_mdspan(pchol.view()), + logp.view()); + std::vector h_logp(n); + raft::update_host(h_logp.data(), logp.data_handle(), n, stream); + raft::resource::sync_stream(handle, stream); + + for (int i = 0; i < n; ++i) { + double s = 0.0; + for (int k = 0; k < K; ++k) { + T r = h_resp[(size_t)i * K + k]; + ASSERT_GE((double)r, -1e-5); + s += (double)r; + } + ASSERT_NEAR(s, 1.0, 1e-3) << "responsibilities row " << i << " not normalized"; + } + + double mean_logp = 0.0; + for (int i = 0; i < n; ++i) + mean_logp += (double)h_logp[i]; + mean_logp /= n; + // tolerance loose for float; lower_bound is the fit-time average log p(x). + double tol = std::is_same_v ? 1e-2 : 1e-5; + ASSERT_NEAR(mean_logp, (double)lower_bound, std::abs((double)lower_bound) * tol + tol); + } + + raft::resources handle; + cudaStream_t stream; +}; + +const std::vector> inputsf = { + {600, 8, 4, covariance_type::FULL}, + {600, 8, 4, covariance_type::TIED}, + {600, 8, 4, covariance_type::DIAG}, + {600, 8, 4, covariance_type::SPHERICAL}, + {2000, 16, 5, covariance_type::FULL}, // fixed-D=16 specialization + {2000, 16, 5, covariance_type::DIAG}, + {2000, 32, 4, covariance_type::FULL}, // fixed-D=32 specialization + {2000, 50, 4, covariance_type::FULL}, // fixed-D=50 specialization + {2000, 64, 4, covariance_type::FULL}, // fixed-D=64 specialization (boundary) + {3000, 128, 4, covariance_type::FULL}, // 64 < d <= 256 -> tiled thread64 kernel + {3000, 128, 4, covariance_type::TIED}, // tied, tiled thread64 kernel + {4000, 300, 4, covariance_type::FULL}, // d>=257 (float) -> cuBLAS E-step route + {2000, 512, 16, covariance_type::DIAG}, // K*d large -> diag global-mem path + {2000, 1024, 16, covariance_type::SPHERICAL}, // K*d large -> spherical global-mem path +}; + +const std::vector> inputsd = { + {600, 8, 
4, covariance_type::FULL}, + {600, 8, 4, covariance_type::TIED}, + {600, 8, 4, covariance_type::DIAG}, + {600, 8, 4, covariance_type::SPHERICAL}, + {2000, 16, 5, covariance_type::FULL}, + {2000, 50, 4, covariance_type::FULL}, // fixed-D=50 specialization + {2000, 64, 4, covariance_type::FULL}, // fixed-D=64 specialization (boundary) + {3000, 128, 4, covariance_type::FULL}, // d>64 (double) -> cuBLAS E-step route +}; + +using GMMTestF = GMMTest; +TEST_P(GMMTestF, Result) { basicTest(); } +INSTANTIATE_TEST_CASE_P(GMMTests, GMMTestF, ::testing::ValuesIn(inputsf)); + +using GMMTestD = GMMTest; +TEST_P(GMMTestD, Result) { basicTest(); } +INSTANTIATE_TEST_CASE_P(GMMTests, GMMTestD, ::testing::ValuesIn(inputsd)); + +// --------------------------------------------------------------------------- +// Standalone tests for behaviors not covered by the parametrized sweep: +// every init method, warm_start, n_init best-of, and the ill-defined- +// covariance error path. +// --------------------------------------------------------------------------- +namespace { + +// Generate well-separated blobs into freshly allocated device buffers; returns {X, yref}, with X allocated as n x d and yref as n entries, both filled by raft::random::make_blobs on the handle's CUDA stream using `seed`. +template +std::pair, raft::device_vector> make_gmm_blobs( + raft::resources const& handle, int n, int d, int K, std::uint64_t seed = 1234ULL) +{ + auto X = raft::make_device_matrix(handle, n, d); + auto yref = raft::make_device_vector(handle, n); + raft::random::make_blobs(X.data_handle(), + yref.data_handle(), + n, + d, + K, + raft::resource::get_cuda_stream(handle), + true, + nullptr, + nullptr, + T(1.0), + false, + static_cast(-10.0f), + static_cast(10.0f), + seed); + return {std::move(X), std::move(yref)}; +} + +} // namespace + +// Every init method produces a valid fit (finite lower bound); kmeans-family inits recover the +// generating clusters on well-separated blobs (ARI > 0.95 against the blob labels). 
+TEST(GMMExtra, InitMethods) +{ + raft::resources handle; + auto stream = raft::resource::get_cuda_stream(handle); + const int n = 1500, d = 8, K = 4; + auto [X, yref] = make_gmm_blobs(handle, n, d, K); + + for (auto im : {init_method::KMeans, + init_method::KMeansPlusPlus, + init_method::Random, + init_method::RandomFromData}) { + int64_t cn = cov_len(covariance_type::FULL, d, K); + auto weights = raft::make_device_vector(handle, K); + auto means = raft::make_device_matrix(handle, K, d); + auto covs = raft::make_device_vector(handle, cn); + auto pchol = raft::make_device_vector(handle, cn); + auto precs = raft::make_device_vector(handle, cn); + auto labels = raft::make_device_vector(handle, n); + + params prm; + prm.n_components = K; + prm.cov_type = covariance_type::FULL; + prm.init = im; + prm.n_init = 1; + // A modest regularizer (1e-2 vs. the 1e-6 default) keeps every init's first covariance well-defined so + // the test deterministically exercises the init code path (random inits + // can otherwise collapse a component, which legitimately raises). + prm.reg_covar = 1e-2; + prm.max_iter = 100; + prm.seed = 1234ULL; + + float lb = 0; + int it = 0; + bool cv = false; + fit(handle, + prm, + raft::make_const_mdspan(X.view()), + weights.view(), + means.view(), + covs.view(), + pchol.view(), + precs.view(), + labels.view(), + raft::make_host_scalar_view(&lb), + raft::make_host_scalar_view(&it), + raft::make_host_scalar_view(&cv)); + + ASSERT_TRUE(std::isfinite(lb)) << "init " << (int)im; + if (im == init_method::KMeans || im == init_method::KMeansPlusPlus) { + double ari = + raft::stats::adjusted_rand_index(yref.data_handle(), labels.data_handle(), n, stream); + raft::resource::sync_stream(handle, stream); + ASSERT_GT(ari, 0.95) << "init " << (int)im; + } + } +} + +// warm_start reuses the supplied weights/means/covariances as the single +// initialization and refines them to a finite, non-decreasing lower bound. 
+TEST(GMMExtra, WarmStart) +{ + raft::resources handle; + const int n = 1500, d = 8, K = 4; + auto [X, yref] = make_gmm_blobs(handle, n, d, K); + + int64_t cn = cov_len(covariance_type::FULL, d, K); + auto weights = raft::make_device_vector(handle, K); + auto means = raft::make_device_matrix(handle, K, d); + auto covs = raft::make_device_vector(handle, cn); + auto pchol = raft::make_device_vector(handle, cn); + auto precs = raft::make_device_vector(handle, cn); + auto labels = raft::make_device_vector(handle, n); + + params prm; + prm.n_components = K; + prm.cov_type = covariance_type::FULL; + prm.init = init_method::KMeans; + prm.max_iter = 5; + prm.seed = 1234ULL; + + double lb1 = 0; + int it1 = 0; + bool cv1 = false; + auto run = [&](bool warm, double& lb, int& it, bool& cv) { + fit(handle, + prm, + raft::make_const_mdspan(X.view()), + weights.view(), + means.view(), + covs.view(), + pchol.view(), + precs.view(), + labels.view(), + raft::make_host_scalar_view(&lb), + raft::make_host_scalar_view(&it), + raft::make_host_scalar_view(&cv), + warm); + }; + run(false, lb1, it1, cv1); + + // Continue from the fitted parameters (cold run capped at max_iter = 5, warm run allowed 20); the lower bound should not regress beyond the 1e-6 slack. + double lb2 = 0; + int it2 = 0; + bool cv2 = false; + prm.max_iter = 20; + run(true, lb2, it2, cv2); + ASSERT_TRUE(std::isfinite(lb2)); + ASSERT_GE(lb2, lb1 - 1e-6); +} + +// n_init>1 keeps the restart with the largest lower bound. The first restart +// of an N-restart fit uses the same seed as a single-restart fit, so more +// restarts can only match or beat it: lower_bound(n_init=N) >= lower_bound(1). 
+TEST(GMMExtra, NInitSelectsBest) +{ + raft::resources handle; + const int n = 1500, d = 8, K = 4; + auto [X, yref] = make_gmm_blobs(handle, n, d, K); + + int64_t cn = cov_len(covariance_type::FULL, d, K); + auto weights = raft::make_device_vector(handle, K); + auto means = raft::make_device_matrix(handle, K, d); + auto covs = raft::make_device_vector(handle, cn); + auto pchol = raft::make_device_vector(handle, cn); + auto precs = raft::make_device_vector(handle, cn); + auto labels = raft::make_device_vector(handle, n); + + auto run = [&](int n_init) { + params prm; + prm.n_components = K; + prm.cov_type = covariance_type::FULL; + prm.init = init_method::Random; // restart-sensitive init + prm.n_init = n_init; + prm.reg_covar = 1e-2; // keep every restart well-defined (vs. the 1e-6 default) + prm.max_iter = 50; + prm.seed = 1234ULL; + float lb = 0; + int it = 0; + bool cv = false; + fit(handle, + prm, + raft::make_const_mdspan(X.view()), + weights.view(), + means.view(), + covs.view(), + pchol.view(), + precs.view(), + labels.view(), + raft::make_host_scalar_view(&lb), + raft::make_host_scalar_view(&it), + raft::make_host_scalar_view(&cv)); + return lb; + }; + + float lb1 = run(1); + float lb10 = run(10); + ASSERT_GE((double)lb10, (double)lb1 - 1e-5); +} + +// A degenerate component (more components than distinct points) yields an +// ill-defined covariance and must surface as an exception rather than NaNs. +TEST(GMMExtra, IllDefinedCovarianceThrows) +{ + raft::resources handle; + auto stream = raft::resource::get_cuda_stream(handle); + const int n = 6, d = 4, K = 5; + + // All points identical (X is zero-filled) -> any component covariance collapses to zero. 
+ auto X = raft::make_device_matrix(handle, n, d); + RAFT_CUDA_TRY(cudaMemsetAsync(X.data_handle(), 0, sizeof(float) * (size_t)n * d, stream)); + + int64_t cn = cov_len(covariance_type::FULL, d, K); + auto weights = raft::make_device_vector(handle, K); + auto means = raft::make_device_matrix(handle, K, d); + auto covs = raft::make_device_vector(handle, cn); + auto pchol = raft::make_device_vector(handle, cn); + auto precs = raft::make_device_vector(handle, cn); + auto labels = raft::make_device_vector(handle, n); + + params prm; + prm.n_components = K; + prm.cov_type = covariance_type::FULL; + prm.init = init_method::RandomFromData; + prm.reg_covar = 0.0; // disable the regularizer that would otherwise mask it + prm.max_iter = 50; + prm.seed = 1234ULL; + + float lb = 0; + int it = 0; + bool cv = false; + EXPECT_ANY_THROW(fit(handle, + prm, + raft::make_const_mdspan(X.view()), + weights.view(), + means.view(), + covs.view(), + pchol.view(), + precs.view(), + labels.view(), + raft::make_host_scalar_view(&lb), + raft::make_host_scalar_view(&it), + raft::make_host_scalar_view(&cv))); +} + +} // namespace cuvs::cluster::gmm diff --git a/fern/docs.yml b/fern/docs.yml index 882dd665aa..d0b866c33b 100644 --- a/fern/docs.yml +++ b/fern/docs.yml @@ -152,6 +152,8 @@ navigation: contents: - page: "K-Means" path: "./pages/cluster/kmeans.md" + - page: "Gaussian Mixture Model" + path: "./pages/cluster/gmm.md" - page: "Single-linkage" path: "./pages/cluster/single_linkage.md" - page: "Spectral Clustering" @@ -263,6 +265,8 @@ navigation: - section: "C API Documentation" path: "./pages/c_api/index.md" contents: + - page: "Cluster Gmm" + path: "./pages/c_api/c-api-cluster-gmm.md" - page: "Cluster Kmeans" path: "./pages/c_api/c-api-cluster-kmeans.md" - page: "Core C API" @@ -316,6 +320,8 @@ navigation: contents: - page: "Cluster Agglomerative" path: "./pages/cpp_api/cpp-api-cluster-agglomerative.md" + - page: "Cluster Gmm" + path: "./pages/cpp_api/cpp-api-cluster-gmm.md" - page: 
"Cluster Kmeans" path: "./pages/cpp_api/cpp-api-cluster-kmeans.md" - page: "Cluster Spectral" @@ -404,6 +410,8 @@ navigation: - section: "Python API Documentation" path: "./pages/python_api/index.md" contents: + - page: "Cluster Gmm" + path: "./pages/python_api/python-api-cluster-gmm.md" - page: "Cluster Kmeans" path: "./pages/python_api/python-api-cluster-kmeans.md" - page: "Common" diff --git a/fern/pages/c_api/c-api-cluster-gmm.md b/fern/pages/c_api/c-api-cluster-gmm.md new file mode 100644 index 0000000000..c767280ec5 --- /dev/null +++ b/fern/pages/c_api/c-api-cluster-gmm.md @@ -0,0 +1,266 @@ +--- +slug: api-reference/c-api-cluster-gmm +--- + +# Gmm + +_Source header: `cuvs/cluster/gmm.h`_ + +## Gaussian mixture hyperparameters + + +### cuvsGMMCovarianceType + +Covariance parameterization of the mixture components. + +```c +typedef enum { + CUVS_GMM_COVARIANCE_FULL = 0, + CUVS_GMM_COVARIANCE_TIED = 1, + CUVS_GMM_COVARIANCE_DIAG = 2, + CUVS_GMM_COVARIANCE_SPHERICAL = 3 +} cuvsGMMCovarianceType; +``` + +**Values** + +| Name | Value | +| --- | --- | +| `CUVS_GMM_COVARIANCE_FULL` | `0` | +| `CUVS_GMM_COVARIANCE_TIED` | `1` | +| `CUVS_GMM_COVARIANCE_DIAG` | `2` | +| `CUVS_GMM_COVARIANCE_SPHERICAL` | `3` | + + +### cuvsGMMInitMethod + +Strategy used to initialize the responsibilities before EM. 
+ +```c +typedef enum { + CUVS_GMM_INIT_KMEANS = 0, + CUVS_GMM_INIT_KMEANS_PLUS_PLUS = 1, + CUVS_GMM_INIT_RANDOM = 2, + CUVS_GMM_INIT_RANDOM_FROM_DATA = 3 +} cuvsGMMInitMethod; +``` + +**Values** + +| Name | Value | +| --- | --- | +| `CUVS_GMM_INIT_KMEANS` | `0` | +| `CUVS_GMM_INIT_KMEANS_PLUS_PLUS` | `1` | +| `CUVS_GMM_INIT_RANDOM` | `2` | +| `CUVS_GMM_INIT_RANDOM_FROM_DATA` | `3` | + + +### cuvsGMMParams + +Hyper-parameters for the Gaussian mixture EM solver + +```c +struct cuvsGMMParams { + int n_components; + cuvsGMMCovarianceType covariance_type; + double tol; + double reg_covar; + int max_iter; + int n_init; + cuvsGMMInitMethod init; + uint64_t seed; +}; +``` + +**Fields** + +| Name | Type | Description | +| --- | --- | --- | +| `n_components` | `int` | The number of mixture components. Default: 1. | +| `covariance_type` | [`cuvsGMMCovarianceType`](/api-reference/c-api-cluster-gmm#cuvsgmmcovariancetype) | Covariance parameterization of the mixture components. Default: FULL. | +| `tol` | `double` | Convergence threshold on the change of the per-sample average log-likelihood (lower bound). Default: 1e-3. | +| `reg_covar` | `double` | Non-negative regularization added to the diagonal of covariance.
Default: 1e-6. | +| `max_iter` | `int` | Maximum number of EM iterations for a single run. Default: 100. | +| `n_init` | `int` | Number of initializations to perform; the best result is kept. Default: 1. | +| `init` | [`cuvsGMMInitMethod`](/api-reference/c-api-cluster-gmm#cuvsgmminitmethod) | Strategy used to initialize the responsibilities before EM.
Default: KMEANS. | +| `seed` | `uint64_t` | Seed to the random number generator. Default: 0. | + + +### cuvsGMMParamsCreate + +Allocate GMM params, and populate with default values + +```c +cuvsError_t cuvsGMMParamsCreate(cuvsGMMParams_t* params); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `params` | in | [`cuvsGMMParams_t*`](/api-reference/c-api-cluster-gmm#cuvsgmmparams) | cuvsGMMParams_t to allocate | + +**Returns** + +[`cuvsError_t`](/api-reference/c-api-core-c-api#cuvserror-t) + + +### cuvsGMMParamsDestroy + +De-allocate GMM params + +```c +cuvsError_t cuvsGMMParamsDestroy(cuvsGMMParams_t params); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `params` | in | [`cuvsGMMParams_t`](/api-reference/c-api-cluster-gmm#cuvsgmmparams) | | + +**Returns** + +[`cuvsError_t`](/api-reference/c-api-core-c-api#cuvserror-t) + +## Gaussian mixture model APIs + + +### cuvsGMMFit + +Fit a Gaussian mixture with the EM algorithm. + +```c +cuvsError_t cuvsGMMFit(cuvsResources_t res, +cuvsGMMParams_t params, +DLManagedTensor* X, +DLManagedTensor* weights, +DLManagedTensor* means, +DLManagedTensor* covariances, +DLManagedTensor* precisions_chol, +DLManagedTensor* precisions, +DLManagedTensor* labels, +double* lower_bound, +int* n_iter, +bool* converged, +bool warm_start); +``` + +Runs ``params->n_init`` random restarts (unless ``warm_start`` is true) and keeps the parameters with the largest lower bound. + +All tensors must reside on device memory and be row-major. ``X``, ``weights``, ``means``, ``covariances``, ``precisions_chol`` and ``precisions`` must share one dtype (float32 or float64); ``labels`` is int32. 
+ +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `res` | in | [`cuvsResources_t`](/api-reference/c-api-core-c-api#cuvsresources-t) | opaque C handle | +| `params` | in | [`cuvsGMMParams_t`](/api-reference/c-api-cluster-gmm#cuvsgmmparams) | Parameters for the GMM model. | +| `X` | in | `DLManagedTensor*` | Training data. [dim = n_samples x n_features] | +| `weights` | inout | `DLManagedTensor*` | Mixture weights. [len = n_components] | +| `means` | inout | `DLManagedTensor*` | Component means. [dim = n_components x n_features] | +| `covariances` | inout | `DLManagedTensor*` | Component covariances, flat. Length by covariance_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. | +| `precisions_chol` | out | `DLManagedTensor*` | Precision Cholesky factors, same flat layout as covariances (FULL/TIED: upper-triangular factor U with precision = U @ Uᵀ; DIAG/SPHERICAL: reciprocal standard deviations). | +| `precisions` | out | `DLManagedTensor*` | Precision matrices, same flat layout as covariances. | +| `labels` | out | `DLManagedTensor*` | Hard component assignment per sample. [len = n_samples] | +| `lower_bound` | out | `double*` | Per-sample average log-likelihood of the best fit. | +| `n_iter` | out | `int*` | Number of EM iterations of the best fit. | +| `converged` | out | `bool*` | Whether the best fit converged within tol. | +| `warm_start` | in | `bool` | Use the incoming weights/means/covariances as the single initialization. | + +**Returns** + +[`cuvsError_t`](/api-reference/c-api-core-c-api#cuvserror-t) + + +### cuvsGMMPredict + +Hard component labels (argmax responsibility) for new data. 
+ +```c +cuvsError_t cuvsGMMPredict(cuvsResources_t res, +cuvsGMMParams_t params, +DLManagedTensor* X, +DLManagedTensor* weights, +DLManagedTensor* means, +DLManagedTensor* precisions_chol, +DLManagedTensor* labels); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `res` | in | [`cuvsResources_t`](/api-reference/c-api-core-c-api#cuvsresources-t) | opaque C handle | +| `params` | in | [`cuvsGMMParams_t`](/api-reference/c-api-cluster-gmm#cuvsgmmparams) | Parameters used to fit the GMM model. | +| `X` | in | `DLManagedTensor*` | Data to assign. [dim = n_samples x n_features] | +| `weights` | in | `DLManagedTensor*` | Fitted mixture weights. [len = n_components] | +| `means` | in | `DLManagedTensor*` | Fitted component means. [dim = n_components x n_features] | +| `precisions_chol` | in | `DLManagedTensor*` | Fitted precision Cholesky factors, flat. Length by covariance_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. | +| `labels` | out | `DLManagedTensor*` | Hard component assignment per sample (int32). [len = n_samples] | + +**Returns** + +[`cuvsError_t`](/api-reference/c-api-core-c-api#cuvserror-t) + + +### cuvsGMMPredictProba + +Posterior responsibilities for new data. + +```c +cuvsError_t cuvsGMMPredictProba(cuvsResources_t res, +cuvsGMMParams_t params, +DLManagedTensor* X, +DLManagedTensor* weights, +DLManagedTensor* means, +DLManagedTensor* precisions_chol, +DLManagedTensor* resp); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `res` | in | [`cuvsResources_t`](/api-reference/c-api-core-c-api#cuvsresources-t) | opaque C handle | +| `params` | in | [`cuvsGMMParams_t`](/api-reference/c-api-cluster-gmm#cuvsgmmparams) | Parameters used to fit the GMM model. | +| `X` | in | `DLManagedTensor*` | Data to evaluate. [dim = n_samples x n_features] | +| `weights` | in | `DLManagedTensor*` | Fitted mixture weights. 
[len = n_components] | +| `means` | in | `DLManagedTensor*` | Fitted component means. [dim = n_components x n_features] | +| `precisions_chol` | in | `DLManagedTensor*` | Fitted precision Cholesky factors, flat. Length by covariance_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. | +| `resp` | out | `DLManagedTensor*` | Posterior probability of each component for each sample. [dim = n_samples x n_components] | + +**Returns** + +[`cuvsError_t`](/api-reference/c-api-core-c-api#cuvserror-t) + + +### cuvsGMMScoreSamples + +Per-sample log-likelihood log p(x_i) for new data. + +```c +cuvsError_t cuvsGMMScoreSamples(cuvsResources_t res, +cuvsGMMParams_t params, +DLManagedTensor* X, +DLManagedTensor* weights, +DLManagedTensor* means, +DLManagedTensor* precisions_chol, +DLManagedTensor* log_prob_norm); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `res` | in | [`cuvsResources_t`](/api-reference/c-api-core-c-api#cuvsresources-t) | opaque C handle | +| `params` | in | [`cuvsGMMParams_t`](/api-reference/c-api-cluster-gmm#cuvsgmmparams) | Parameters used to fit the GMM model. | +| `X` | in | `DLManagedTensor*` | Data to evaluate. [dim = n_samples x n_features] | +| `weights` | in | `DLManagedTensor*` | Fitted mixture weights. [len = n_components] | +| `means` | in | `DLManagedTensor*` | Fitted component means. [dim = n_components x n_features] | +| `precisions_chol` | in | `DLManagedTensor*` | Fitted precision Cholesky factors, flat. Length by covariance_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. | +| `log_prob_norm` | out | `DLManagedTensor*` | Log-likelihood of each sample under the model. 
[len = n_samples] | + +**Returns** + +[`cuvsError_t`](/api-reference/c-api-core-c-api#cuvserror-t) diff --git a/fern/pages/c_api/index.md b/fern/pages/c_api/index.md index 8c72554709..619059a75e 100644 --- a/fern/pages/c_api/index.md +++ b/fern/pages/c_api/index.md @@ -4,6 +4,7 @@ These pages are generated from the documented public headers in the cuVS source ## Cluster +- [Gmm](/api-reference/c-api-cluster-gmm) - [K-Means](/api-reference/c-api-cluster-kmeans) ## Common diff --git a/fern/pages/cluster/gmm.md b/fern/pages/cluster/gmm.md new file mode 100644 index 0000000000..703f0ab756 --- /dev/null +++ b/fern/pages/cluster/gmm.md @@ -0,0 +1,178 @@ +# Gaussian Mixture Model + +A Gaussian Mixture Model (GMM) is a GPU-accelerated probabilistic clustering and density-estimation algorithm. It models a dataset as a weighted sum of `n_components` Gaussian components, learning a weight, mean, and covariance for each one with the Expectation-Maximization (EM) algorithm. + +Use a GMM when you want soft cluster assignments (a probability that each row belongs to each component), a generative density model you can score new points against, or clusters that are elliptical rather than spherical. Unlike K-Means, which assigns every row to exactly one centroid, a GMM returns a full responsibility distribution over components and captures per-component shape through its covariance. Its primary outputs are the component weights, means, covariances (and their Cholesky-factored precisions), per-row labels, and the converged log-likelihood lower bound. + +## Example API Usage + +[C API](/api-reference/c-api-cluster-gmm) | [C++ API](/api-reference/cpp-api-cluster-gmm) | [Python API](/api-reference/python-api-cluster-gmm) + +### Fitting a mixture + +Fitting learns the component weights, means, and covariances from a dataset on the device. 
The covariance-shaped outputs (`covariances`, `precisions_chol`, `precisions`) have a layout that depends on `covariance_type`; see the API reference for the exact shapes. + + + + +```c +#include <cuvs/core/c_api.h> +#include <cuvs/cluster/gmm.h> + +cuvsResources_t res; +cuvsGMMParams_t params; +DLManagedTensor *dataset; +DLManagedTensor *weights, *means, *covariances, *precisions_chol, *precisions, *labels; +double lower_bound; +int n_iter; +bool converged; + +load_dataset(dataset); +allocate_outputs(weights, means, covariances, precisions_chol, precisions, labels); + +cuvsResourcesCreate(&res); +cuvsGMMParamsCreate(&params); + +params->n_components = 1024; +params->covariance_type = CUVS_GMM_COVARIANCE_FULL; +params->max_iter = 100; +params->tol = 1e-3; + +cuvsGMMFit(res, params, dataset, weights, means, covariances, precisions_chol, + precisions, labels, &lower_bound, &n_iter, &converged, false); + +cuvsGMMParamsDestroy(params); +cuvsResourcesDestroy(res); +``` + + + + +```cpp +#include + +#include +#include + +using namespace cuvs::cluster; + +raft::device_resources res; +raft::device_matrix_view dataset = load_dataset(); + +gmm::params params; +params.n_components = 1024; +params.cov_type = gmm::covariance_type::FULL; +params.max_iter = 100; +params.tol = 1e-3; + +// Output buffers; covariance-shaped extents follow params.cov_type. 
+auto [weights, means, covariances, precisions_chol, precisions, labels] = + allocate_outputs(res, params, dataset); + +float lower_bound; +int n_iter; +bool converged; + +gmm::fit(res, + params, + dataset, + weights.view(), + means.view(), + covariances.view(), + precisions_chol.view(), + precisions.view(), + labels.view(), + raft::make_host_scalar_view(&lower_bound), + raft::make_host_scalar_view(&n_iter), + raft::make_host_scalar_view(&converged)); +``` + + + + +```python +import cupy as cp + +from cuvs.cluster.gmm import GMMParams, fit + +dataset = cp.asarray(load_dataset(), dtype=cp.float32) +params = GMMParams(n_components=1024, covariance_type="full", max_iter=100, tol=1e-3) + +out = fit(params, dataset) +# out.weights, out.means, out.covariances, out.precisions_chol, out.labels, ... +``` + + + + +### Assigning labels and scoring + +After fitting, reuse the learned `weights`, `means`, and `precisions_chol` to assign hard labels (`predict`), produce per-component responsibilities (`predict_proba`), or evaluate the per-row log-likelihood of new data (`score_samples`). + + + + +```python +from cuvs.cluster.gmm import predict, predict_proba, score_samples + +out = fit(params, dataset) + +# Hard label per row (argmax responsibility). +labels = predict(params, dataset, out.weights, out.means, out.precisions_chol) + +# (n_samples, n_components) responsibility matrix. +resp = predict_proba(params, dataset, out.weights, out.means, out.precisions_chol) + +# Per-row log-likelihood under the fitted mixture. +log_prob = score_samples(params, dataset, out.weights, out.means, out.precisions_chol) +``` + + + + +The same three calls are available in C (`cuvsGMMPredict`, `cuvsGMMPredictProba`, `cuvsGMMScoreSamples`) and C++ (`gmm::predict`, `gmm::predict_proba`, `gmm::score_samples`). + +## How GMM works + +EM alternates between two steps until the average log-likelihood stops improving: + +1. 
**E-step**: given the current parameters, compute the responsibility of each component for each row — the posterior probability that the row was generated by that component. +2. **M-step**: given the responsibilities, update each component's weight, mean, and covariance to the responsibility-weighted statistics of the data. + +The algorithm repeats until it reaches `max_iter` or the per-sample average log-likelihood changes by less than `tol`. Because both steps reduce to dense linear algebra over many rows and components, the GPU is well suited to the work. + +## Covariance types + +`covariance_type` controls how much shape each component can express, trading flexibility for parameters and cost: + +| Type | Description | +| --- | --- | +| `full` | Each component has its own full covariance matrix. Most flexible, most expensive. | +| `tied` | All components share a single full covariance matrix. | +| `diag` | Each component has its own diagonal covariance (axis-aligned ellipsoids). | +| `spherical` | Each component has a single variance (isotropic). Fewest parameters, fastest. | + +## When to use + +Use a GMM when soft, probabilistic assignments matter, when components are elliptical or have different shapes, or when you need a density model to score or compare new points. Prefer `full` or `diag` covariances when component shape is informative, and `spherical` or `tied` when data is limited or speed matters more than per-component shape. If you only need hard, roughly spherical partitions, K-Means is simpler and faster. + +## Configuration parameters + +| Parameter | Default | Description | +| --- | --- | --- | +| `n_components` | `1` | Number of mixture components. Larger values fit finer structure but increase work and parameter memory. | +| `covariance_type` | `full` | Covariance parameterization (`full`, `tied`, `diag`, `spherical`). | +| `tol` | `1e-3` | Convergence threshold on the change of the per-sample average log-likelihood. 
| +| `reg_covar` | `1e-6` | Non-negative regularization added to the covariance diagonal for numerical stability. | +| `max_iter` | `100` | Maximum number of EM iterations for one run. | +| `n_init` | `1` | Number of independent runs with different seeds; the best result is kept. | +| `init_method` | `kmeans` | Responsibility initialization: `kmeans`, `kmeans++`, `random`, or `random_from_data`. | +| `seed` | `0` | Seed for the random number generator. | + +## Tuning + +Start with `n_components` and `covariance_type`. More components and richer covariances capture more structure but cost more memory and time, and can overfit when data is limited; raise `reg_covar` if covariances become ill-conditioned. Use `kmeans` initialization for robust default seeding, and increase `n_init` when different seeds produce noticeably different log-likelihoods. Tune `max_iter` and `tol` together: if `n_iter` regularly reaches `max_iter`, increase `max_iter` or relax `tol`. + +## Memory footprint + +Fitting streams the E and M steps over tiles of rows, so it never materializes the full `(n_samples, n_components)` responsibility matrix — peak device memory stays bounded by the input data, the model parameters, and one responsibility tile, independent of `n_samples`. `predict` and `score_samples` likewise avoid the full responsibility matrix; only `predict_proba` materializes the `(n_samples, n_components)` output because that matrix is its result. diff --git a/fern/pages/cluster/index.md b/fern/pages/cluster/index.md index 387dccd93b..d265ad6e4f 100644 --- a/fern/pages/cluster/index.md +++ b/fern/pages/cluster/index.md @@ -7,5 +7,6 @@ slug: user-guide/api-guides/clustering-guide Use these guides for NVIDIA cuVS clustering APIs that group related vectors or build graph structures from pairwise relationships. - [K-Means](/user-guide/api-guides/clustering-guide/k-means): partition vectors into a fixed number of clusters, often for scalable vector-search partitioning. 
+- [Gaussian Mixture Model](/user-guide/api-guides/clustering-guide/gaussian-mixture-model): fit a probabilistic mixture of Gaussians for soft, shape-aware clustering and density estimation. - [Single-linkage](/user-guide/api-guides/clustering-guide/single-linkage): build hierarchical clusters from nearest-neighbor relationships. - [Spectral Clustering](/user-guide/api-guides/clustering-guide/spectral-clustering): use graph structure and spectral methods to find clusters with more complex shapes. diff --git a/fern/pages/cpp_api/cpp-api-cluster-gmm.md b/fern/pages/cpp_api/cpp-api-cluster-gmm.md new file mode 100644 index 0000000000..625a4801c6 --- /dev/null +++ b/fern/pages/cpp_api/cpp-api-cluster-gmm.md @@ -0,0 +1,352 @@ +--- +slug: api-reference/cpp-api-cluster-gmm +--- + +# Gmm + +_Source header: `cuvs/cluster/gmm.hpp`_ + +## Gaussian mixture hyperparameters + + +### cluster::gmm::covariance_type + +Covariance parameterization of the mixture components. + +```cpp +enum class covariance_type { + FULL = 0, + TIED = 1, + DIAG = 2, + SPHERICAL = 3 +}; +``` + +**Values** + +| Name | Value | +| --- | --- | +| `FULL` | `0` | +| `TIED` | `1` | +| `DIAG` | `2` | +| `SPHERICAL` | `3` | + + +### cluster::gmm::init_method + +Strategy used to initialize the responsibilities before EM. + +```cpp +enum class init_method { + KMeans = 0, + KMeansPlusPlus = 1, + Random = 2, + RandomFromData = 3 +}; +``` + +**Values** + +| Name | Value | +| --- | --- | +| `KMeans` | `0` | +| `KMeansPlusPlus` | `1` | +| `Random` | `2` | +| `RandomFromData` | `3` | + + +### cluster::gmm::params + +Hyper-parameters for the Gaussian mixture EM solver. + +```cpp +struct params { + int n_components; + covariance_type cov_type; + double tol; + double reg_covar; + int max_iter; + int n_init; + init_method init; + uint64_t seed; +}; +``` + +**Fields** + +| Name | Type | Description | +| --- | --- | --- | +| `n_components` | `int` | The number of mixture components. Default: 1. 
| +| `cov_type` | [`covariance_type`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-covariance-type) | Covariance parameterization of the mixture components. Default: FULL. | +| `tol` | `double` | Convergence threshold on the change of the per-sample average log-likelihood (lower bound). Default: 1e-3. | +| `reg_covar` | `double` | Non-negative regularization added to the diagonal of covariance.
Default: 1e-6. | +| `max_iter` | `int` | Maximum number of EM iterations for a single run. Default: 100. | +| `n_init` | `int` | Number of initializations to perform; the best result is kept.
Default: 1. | +| `init` | [`init_method`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-init-method) | Strategy used to initialize the responsibilities before EM.
Default: KMeans. | +| `seed` | `uint64_t` | Seed to the random number generator. Default: 0. | + +## Gaussian mixture model APIs + + +### cluster::gmm::fit + +Fit a Gaussian mixture with the EM algorithm. + +```cpp +void fit(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view covariances, +raft::device_vector_view precisions_chol, +raft::device_vector_view precisions, +raft::device_vector_view labels, +raft::host_scalar_view lower_bound, +raft::host_scalar_view n_iter, +raft::host_scalar_view converged, +bool warm_start = false); +``` + +Runs ``params.n_init`` random restarts (unless `warm_start` is true) and keeps the parameters with the largest lower bound. Writes the fitted ``weights``, ``means``, ``covariances``, ``precisions_chol`` and ``precisions``, the per-sample hard ``labels`` (argmax of the final responsibilities), and the scalar ``lower_bound`` / ``n_iter`` / ``converged`` diagnostics. + +When `warm_start` is true the incoming ``weights`` / ``means`` / ``covariances`` are used as the single initialization and ``params.n_init`` is ignored. + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | in | `raft::resources const&` | The raft resources handle. | +| `params` | in | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | Hyper-parameters of the EM solver. | +| `X` | in | `raft::device_matrix_view` | Training data, row-major. [dim = n_samples x n_features] | +| `weights` | inout | `raft::device_vector_view` | Mixture weights. [len = n_components] | +| `means` | inout | `raft::device_matrix_view` | Component means, row-major. [dim = n_components x n_features] | +| `covariances` | inout | `raft::device_vector_view` | Component covariances, flat. Length depends on cov_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. 
| +| `precisions_chol` | out | `raft::device_vector_view` | Precision Cholesky factors, same flat layout as covariances. FULL/TIED hold the upper-triangular factor U (precision = U @ Uᵀ); DIAG/SPHERICAL hold reciprocal standard deviations. | +| `precisions` | out | `raft::device_vector_view` | Precision matrices, same flat layout as covariances. | +| `labels` | out | `raft::device_vector_view` | Hard component assignment per sample. [len = n_samples] | +| `lower_bound` | out | `raft::host_scalar_view` | Per-sample average log-likelihood of the best fit. | +| `n_iter` | out | `raft::host_scalar_view` | Number of EM iterations of the best fit. | +| `converged` | out | `raft::host_scalar_view` | Whether the best fit converged within ``params.tol``. | +| `warm_start` | in | `bool` | Use the incoming weights/means/covariances as the single initialization.
Default: `false`. | + +**Returns** + +`void` + +**Additional overload:** `cluster::gmm::fit` + +```cpp +void fit(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view covariances, +raft::device_vector_view precisions_chol, +raft::device_vector_view precisions, +raft::device_vector_view labels, +raft::host_scalar_view lower_bound, +raft::host_scalar_view n_iter, +raft::host_scalar_view converged, +bool warm_start = false); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | | `raft::resources const&` | | +| `params` | | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | | +| `X` | | `raft::device_matrix_view` | | +| `weights` | | `raft::device_vector_view` | | +| `means` | | `raft::device_matrix_view` | | +| `covariances` | | `raft::device_vector_view` | | +| `precisions_chol` | | `raft::device_vector_view` | | +| `precisions` | | `raft::device_vector_view` | | +| `labels` | | `raft::device_vector_view` | | +| `lower_bound` | | `raft::host_scalar_view` | | +| `n_iter` | | `raft::host_scalar_view` | | +| `converged` | | `raft::host_scalar_view` | | +| `warm_start` | | `bool` | Default: `false`. | + +**Returns** + +`void` + + +### cluster::gmm::predict + +Hard component labels (argmax responsibility) for new data. + +```cpp +void predict(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view precisions_chol, +raft::device_vector_view labels); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | in | `raft::resources const&` | The raft resources handle. 
| +| `params` | in | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | Fit hyper-parameters; only n_components and cov_type are consulted at inference time. | +| `X` | in | `raft::device_matrix_view` | Data to assign, row-major. [dim = n_samples x n_features] | +| `weights` | in | `raft::device_vector_view` | Fitted mixture weights. [len = n_components] | +| `means` | in | `raft::device_matrix_view` | Fitted component means. [dim = n_components x n_features] | +| `precisions_chol` | in | `raft::device_vector_view` | Fitted precision Cholesky factors, flat. Length by cov_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. | +| `labels` | out | `raft::device_vector_view` | Hard component assignment per sample. [len = n_samples] | + +**Returns** + +`void` + +**Additional overload:** `cluster::gmm::predict` + +```cpp +void predict(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view precisions_chol, +raft::device_vector_view labels); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | | `raft::resources const&` | | +| `params` | | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | | +| `X` | | `raft::device_matrix_view` | | +| `weights` | | `raft::device_vector_view` | | +| `means` | | `raft::device_matrix_view` | | +| `precisions_chol` | | `raft::device_vector_view` | | +| `labels` | | `raft::device_vector_view` | | + +**Returns** + +`void` + + +### cluster::gmm::predict_proba + +Posterior responsibilities for new data. 
+ +```cpp +void predict_proba(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view precisions_chol, +raft::device_matrix_view resp); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | in | `raft::resources const&` | The raft resources handle. | +| `params` | in | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | Fit hyper-parameters; only n_components and cov_type are consulted at inference time. | +| `X` | in | `raft::device_matrix_view` | Data to evaluate, row-major. [dim = n_samples x n_features] | +| `weights` | in | `raft::device_vector_view` | Fitted mixture weights. [len = n_components] | +| `means` | in | `raft::device_matrix_view` | Fitted component means. [dim = n_components x n_features] | +| `precisions_chol` | in | `raft::device_vector_view` | Fitted precision Cholesky factors, flat. Length by cov_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. | +| `resp` | out | `raft::device_matrix_view` | Posterior probability of each component for each sample, row-major. 
[dim = n_samples x n_components] | + +**Returns** + +`void` + +**Additional overload:** `cluster::gmm::predict_proba` + +```cpp +void predict_proba(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view precisions_chol, +raft::device_matrix_view resp); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | | `raft::resources const&` | | +| `params` | | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | | +| `X` | | `raft::device_matrix_view` | | +| `weights` | | `raft::device_vector_view` | | +| `means` | | `raft::device_matrix_view` | | +| `precisions_chol` | | `raft::device_vector_view` | | +| `resp` | | `raft::device_matrix_view` | | + +**Returns** + +`void` + + +### cluster::gmm::score_samples + +Per-sample log-likelihood log p(x_i) for new data. + +```cpp +void score_samples(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view precisions_chol, +raft::device_vector_view log_prob_norm); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | in | `raft::resources const&` | The raft resources handle. | +| `params` | in | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | Fit hyper-parameters; only n_components and cov_type are consulted at inference time. | +| `X` | in | `raft::device_matrix_view` | Data to evaluate, row-major. [dim = n_samples x n_features] | +| `weights` | in | `raft::device_vector_view` | Fitted mixture weights. [len = n_components] | +| `means` | in | `raft::device_matrix_view` | Fitted component means. [dim = n_components x n_features] | +| `precisions_chol` | in | `raft::device_vector_view` | Fitted precision Cholesky factors, flat. 
Length by cov_type (K=n_components, d=n_features): FULL K*d*d, TIED d*d, DIAG K*d, SPHERICAL K. | +| `log_prob_norm` | out | `raft::device_vector_view` | Log-likelihood of each sample under the model. [len = n_samples] | + +**Returns** + +`void` + +**Additional overload:** `cluster::gmm::score_samples` + +```cpp +void score_samples(raft::resources const& handle, +const params& params, +raft::device_matrix_view X, +raft::device_vector_view weights, +raft::device_matrix_view means, +raft::device_vector_view precisions_chol, +raft::device_vector_view log_prob_norm); +``` + +**Parameters** + +| Name | Direction | Type | Description | +| --- | --- | --- | --- | +| `handle` | | `raft::resources const&` | | +| `params` | | [`const params&`](/api-reference/cpp-api-cluster-gmm#cluster-gmm-params) | | +| `X` | | `raft::device_matrix_view` | | +| `weights` | | `raft::device_vector_view` | | +| `means` | | `raft::device_matrix_view` | | +| `precisions_chol` | | `raft::device_vector_view` | | +| `log_prob_norm` | | `raft::device_vector_view` | | + +**Returns** + +`void` diff --git a/fern/pages/cpp_api/index.md b/fern/pages/cpp_api/index.md index e61927491a..d5904b5db4 100644 --- a/fern/pages/cpp_api/index.md +++ b/fern/pages/cpp_api/index.md @@ -5,6 +5,7 @@ These pages are generated from the documented public headers in the cuVS source ## Cluster - [Agglomerative](/api-reference/cpp-api-cluster-agglomerative) +- [Gmm](/api-reference/cpp-api-cluster-gmm) - [K-Means](/api-reference/cpp-api-cluster-kmeans) - [Spectral](/api-reference/cpp-api-cluster-spectral) diff --git a/fern/pages/python_api/index.md b/fern/pages/python_api/index.md index 8522cf0a8f..2b21c1d633 100644 --- a/fern/pages/python_api/index.md +++ b/fern/pages/python_api/index.md @@ -4,6 +4,7 @@ These pages are generated from the Python and Cython sources under `python/cuvs/ ## Cluster +- [Gmm](/api-reference/python-api-cluster-gmm) - [Kmeans](/api-reference/python-api-cluster-kmeans) ## Common diff --git 
a/fern/pages/python_api/python-api-cluster-gmm.md b/fern/pages/python_api/python-api-cluster-gmm.md new file mode 100644 index 0000000000..100bf95ba2 --- /dev/null +++ b/fern/pages/python_api/python-api-cluster-gmm.md @@ -0,0 +1,281 @@ +--- +slug: api-reference/python-api-cluster-gmm +--- + +# Gmm + +_Python module: `cuvs.cluster.gmm`_ + +## GMMParams + +```python +cdef class GMMParams +``` + +Hyper-parameters for the Gaussian mixture EM solver + +**Parameters** + +| Name | Type | Description | +| --- | --- | --- | +| `n_components` | `int` | The number of mixture components. | +| `covariance_type` | `str` | Covariance parameterization, one of "full", "tied", "diag", "spherical". Matches scikit-learn's ``GaussianMixture``. | +| `tol` | `float` | Convergence threshold on the change of the per-sample average log-likelihood (lower bound). | +| `reg_covar` | `float` | Non-negative regularization added to the diagonal of covariance. | +| `max_iter` | `int` | Maximum number of EM iterations for a single run. | +| `n_init` | `int` | Number of initializations to perform; the best result is kept. | +| `init_method` | `str` | Strategy used to initialize the responsibilities before EM. One of: "kmeans" : run k-means and use the hard labels "k-means++" : use the k-means++ seeding labels "random" : random responsibilities, normalized per sample "random_from_data" : pick n_components samples as one-hot responsibilities | +| `seed` | `int` | Seed to the random number generator. 
| + +**Constructor** + +```python +def __init__(self, *, n_components=None, covariance_type=None, tol=None, reg_covar=None, max_iter=None, n_init=None, init_method=None, seed=None) +``` + +**Members** + +| Name | Kind | +| --- | --- | +| `n_components` | property | +| `covariance_type` | property | +| `tol` | property | +| `reg_covar` | property | +| `max_iter` | property | +| `n_init` | property | +| `init_method` | property | +| `seed` | property | + +### n_components + +```python +def n_components(self) +``` + +### covariance_type + +```python +def covariance_type(self) +``` + +### tol + +```python +def tol(self) +``` + +### reg_covar + +```python +def reg_covar(self) +``` + +### max_iter + +```python +def max_iter(self) +``` + +### n_init + +```python +def n_init(self) +``` + +### init_method + +```python +def init_method(self) +``` + +### seed + +```python +def seed(self) +``` + +## fit + +`@auto_sync_resources` +`@auto_convert_output` + +```python +def fit(GMMParams params, X, weights=None, means=None, covariances=None, warm_start=False, resources=None) +``` + +Fit a Gaussian mixture model with the EM algorithm + +**Parameters** + +| Name | Type | Description | +| --- | --- | --- | +| `params` | `GMMParams` | Parameters of the EM solver. | +| `X` | `Input CUDA array interface compliant matrix shape (m, k)` | | +| `weights` | `CUDA array interface compliant vector, optional` | Writable, shape (n_components,). Holds the initial mixture weights when ``warm_start`` is True and receives the fitted weights. | +| `means` | `CUDA array interface compliant matrix, optional` | Writable, shape (n_components, k). Holds the initial means when ``warm_start`` is True and receives the fitted means. | +| `covariances` | `CUDA array interface compliant array, optional` | Writable; shape depends on ``params.covariance_type`` ("full": (n_components, k, k), "tied": (k, k), "diag": (n_components, k), "spherical": (n_components,)).
Holds the initial covariances when ``warm_start`` is True and receives the fitted covariances. | +| `warm_start` | `bool` | Use the provided weights/means/covariances as the single initialization instead of running ``params.n_init`` restarts. | +| `resources` | `cuvs.common.Resources, optional` | | + +**Returns** + +| Name | Type | Description | +| --- | --- | --- | +| `weights` | `raft.device_ndarray` | Fitted mixture weights, shape (n_components,) | +| `means` | `raft.device_ndarray` | Fitted component means, shape (n_components, k) | +| `covariances` | `raft.device_ndarray` | Fitted covariances (shape depends on covariance_type) | +| `precisions_chol` | `raft.device_ndarray` | Precision Cholesky factors (shape depends on covariance_type) | +| `precisions` | `raft.device_ndarray` | Precision matrices (shape depends on covariance_type) | +| `labels` | `raft.device_ndarray` | Hard component assignment per sample, shape (m,) | +| `lower_bound` | `float` | Per-sample average log-likelihood of the best fit | +| `n_iter` | `int` | Number of EM iterations of the best fit | +| `converged` | `bool` | Whether the best fit converged within ``params.tol`` | + +**Examples** + +```python +>>> import cupy as cp +>>> +>>> from cuvs.cluster.gmm import fit, GMMParams +>>> +>>> n_samples = 5000 +>>> n_features = 50 +>>> n_components = 3 +>>> +>>> X = cp.random.random_sample((n_samples, n_features), +... dtype=cp.float32) +``` + +```python +>>> params = GMMParams(n_components=n_components) +>>> weights, means, covariances, precisions_chol, *_ = fit(params, X) +``` + +## predict + +`@auto_sync_resources` +`@auto_convert_output` + +```python +def predict(GMMParams params, X, weights, means, precisions_chol, labels=None, resources=None) +``` + +Hard component labels (argmax responsibility) for new data + +**Parameters** + +| Name | Type | Description | +| --- | --- | --- | +| `params` | `GMMParams` | Parameters used to fit the GMM model. 
| +| `X` | `Input CUDA array interface compliant matrix shape (m, k)` | | +| `weights` | `Fitted mixture weights, shape (n_components,)` | | +| `means` | `Fitted component means, shape (n_components, k)` | | +| `precisions_chol` | `Fitted precision Cholesky factors (shape depends on covariance_type)` | | +| `labels` | `Optional preallocated CUDA array interface vector shape (m,)` | to hold the output (int32) | +| `resources` | `cuvs.common.Resources, optional` | | + +**Returns** + +| Name | Type | Description | +| --- | --- | --- | +| `labels` | `raft.device_ndarray` | Component assignment for each datapoint in X | + +**Examples** + +```python +>>> import cupy as cp +>>> +>>> from cuvs.cluster.gmm import fit, predict, GMMParams +>>> +>>> X = cp.random.random_sample((5000, 50), dtype=cp.float32) +>>> params = GMMParams(n_components=3) +>>> weights, means, covariances, precisions_chol, *_ = fit(params, X) +>>> +>>> labels = predict(params, X, weights, means, precisions_chol) +``` + +## predict_proba + +`@auto_sync_resources` +`@auto_convert_output` + +```python +def predict_proba(GMMParams params, X, weights, means, precisions_chol, resp=None, resources=None) +``` + +Posterior responsibilities for new data + +**Parameters** + +| Name | Type | Description | +| --- | --- | --- | +| `params` | `GMMParams` | Parameters used to fit the GMM model. 
| +| `X` | `Input CUDA array interface compliant matrix shape (m, k)` | | +| `weights` | `Fitted mixture weights, shape (n_components,)` | | +| `means` | `Fitted component means, shape (n_components, k)` | | +| `precisions_chol` | `Fitted precision Cholesky factors (shape depends on covariance_type)` | | +| `resp` | `Optional preallocated CUDA array interface matrix shape` | (m, n_components) to hold the output | +| `resources` | `cuvs.common.Resources, optional` | | + +**Returns** + +| Name | Type | Description | +| --- | --- | --- | +| `resp` | `raft.device_ndarray` | Posterior probability of each component for each sample | + +**Examples** + +```python +>>> import cupy as cp +>>> +>>> from cuvs.cluster.gmm import fit, predict_proba, GMMParams +>>> +>>> X = cp.random.random_sample((5000, 50), dtype=cp.float32) +>>> params = GMMParams(n_components=3) +>>> weights, means, covariances, precisions_chol, *_ = fit(params, X) +>>> +>>> resp = predict_proba(params, X, weights, means, precisions_chol) +``` + +## score_samples + +`@auto_sync_resources` +`@auto_convert_output` + +```python +def score_samples(GMMParams params, X, weights, means, precisions_chol, log_prob=None, resources=None) +``` + +Per-sample log-likelihood log p(x_i) for new data + +**Parameters** + +| Name | Type | Description | +| --- | --- | --- | +| `params` | `GMMParams` | Parameters used to fit the GMM model. 
| +| `X` | `Input CUDA array interface compliant matrix shape (m, k)` | | +| `weights` | `Fitted mixture weights, shape (n_components,)` | | +| `means` | `Fitted component means, shape (n_components, k)` | | +| `precisions_chol` | `Fitted precision Cholesky factors (shape depends on covariance_type)` | | +| `log_prob` | `Optional preallocated CUDA array interface vector shape (m,)` | to hold the output | +| `resources` | `cuvs.common.Resources, optional` | | + +**Returns** + +| Name | Type | Description | +| --- | --- | --- | +| `log_prob` | `raft.device_ndarray` | Log-likelihood of each sample under the model | + +**Examples** + +```python +>>> import cupy as cp +>>> +>>> from cuvs.cluster.gmm import fit, score_samples, GMMParams +>>> +>>> X = cp.random.random_sample((5000, 50), dtype=cp.float32) +>>> params = GMMParams(n_components=3) +>>> weights, means, covariances, precisions_chol, *_ = fit(params, X) +>>> +>>> log_prob = score_samples(params, X, weights, means, precisions_chol) +``` diff --git a/python/cuvs/cuvs/cluster/CMakeLists.txt b/python/cuvs/cuvs/cluster/CMakeLists.txt index 5e2df9a60b..fcd81ab9c8 100644 --- a/python/cuvs/cuvs/cluster/CMakeLists.txt +++ b/python/cuvs/cuvs/cluster/CMakeLists.txt @@ -1,8 +1,9 @@ # ============================================================================= # cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # cmake-format: on # ============================================================================= +add_subdirectory(gmm) add_subdirectory(kmeans) diff --git a/python/cuvs/cuvs/cluster/__init__.py b/python/cuvs/cuvs/cluster/__init__.py index ec29a2139a..485d21a7f3 100644 --- a/python/cuvs/cuvs/cluster/__init__.py +++ b/python/cuvs/cuvs/cluster/__init__.py @@ -1,7 +1,7 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -from cuvs.cluster import kmeans +from cuvs.cluster import gmm, kmeans -__all__ = ["kmeans"] +__all__ = ["gmm", "kmeans"] diff --git a/python/cuvs/cuvs/cluster/gmm/CMakeLists.txt b/python/cuvs/cuvs/cluster/gmm/CMakeLists.txt new file mode 100644 index 0000000000..37394f4a15 --- /dev/null +++ b/python/cuvs/cuvs/cluster/gmm/CMakeLists.txt @@ -0,0 +1,16 @@ +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources gmm.pyx) +set(linked_libraries cuvs::cuvs cuvs::c_api) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS cuvs MODULE_PREFIX cluster_gmm_ +) diff --git a/python/cuvs/cuvs/cluster/gmm/__init__.pxd b/python/cuvs/cuvs/cluster/gmm/__init__.pxd new file mode 100644 index 0000000000..8eca3cc68a --- /dev/null +++ b/python/cuvs/cuvs/cluster/gmm/__init__.pxd @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 diff --git a/python/cuvs/cuvs/cluster/gmm/__init__.py b/python/cuvs/cuvs/cluster/gmm/__init__.py new file mode 100644 index 0000000000..467dbf1276 --- /dev/null +++ b/python/cuvs/cuvs/cluster/gmm/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 + + +from .gmm import GMMParams, fit, predict, predict_proba, score_samples + +__all__ = ["GMMParams", "fit", "predict", "predict_proba", "score_samples"] diff --git a/python/cuvs/cuvs/cluster/gmm/gmm.pxd b/python/cuvs/cuvs/cluster/gmm/gmm.pxd new file mode 100644 index 0000000000..73a9bd4dfd --- /dev/null +++ b/python/cuvs/cuvs/cluster/gmm/gmm.pxd @@ -0,0 +1,79 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# +# cython: language_level=3 + +from libc.stdint cimport int64_t, uint64_t +from libcpp cimport bool + +from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t +from cuvs.common.cydlpack cimport DLManagedTensor + + +cdef extern from "cuvs/cluster/gmm.h" nogil: + ctypedef enum cuvsGMMCovarianceType: + CUVS_GMM_COVARIANCE_FULL + CUVS_GMM_COVARIANCE_TIED + CUVS_GMM_COVARIANCE_DIAG + CUVS_GMM_COVARIANCE_SPHERICAL + + ctypedef enum cuvsGMMInitMethod: + CUVS_GMM_INIT_KMEANS + CUVS_GMM_INIT_KMEANS_PLUS_PLUS + CUVS_GMM_INIT_RANDOM + CUVS_GMM_INIT_RANDOM_FROM_DATA + + ctypedef struct cuvsGMMParams: + int n_components, + cuvsGMMCovarianceType covariance_type, + double tol, + double reg_covar, + int max_iter, + int n_init, + cuvsGMMInitMethod init, + uint64_t seed + + ctypedef cuvsGMMParams* cuvsGMMParams_t + + cuvsError_t cuvsGMMParamsCreate(cuvsGMMParams_t* params) + + cuvsError_t cuvsGMMParamsDestroy(cuvsGMMParams_t params) + + cuvsError_t cuvsGMMFit(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* covariances, + DLManagedTensor* precisions_chol, + DLManagedTensor* precisions, + DLManagedTensor* labels, + double* lower_bound, + int* n_iter, + bool* converged, + bool warm_start) except + + + cuvsError_t cuvsGMMPredict(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* 
precisions_chol, + DLManagedTensor* labels) except + + + cuvsError_t cuvsGMMPredictProba(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* resp) except + + + cuvsError_t cuvsGMMScoreSamples(cuvsResources_t res, + cuvsGMMParams_t params, + DLManagedTensor* X, + DLManagedTensor* weights, + DLManagedTensor* means, + DLManagedTensor* precisions_chol, + DLManagedTensor* log_prob_norm) except + diff --git a/python/cuvs/cuvs/cluster/gmm/gmm.pyx b/python/cuvs/cuvs/cluster/gmm/gmm.pyx new file mode 100644 index 0000000000..01b846b15e --- /dev/null +++ b/python/cuvs/cuvs/cluster/gmm/gmm.pyx @@ -0,0 +1,551 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# +# cython: language_level=3 + +from collections import namedtuple + +import numpy as np + +cimport cuvs.common.cydlpack + +from cuvs.common.resources import auto_sync_resources + +from libcpp cimport bool + +from cuvs.common cimport cydlpack + +from pylibraft.common import auto_convert_output, device_ndarray +from pylibraft.common.cai_wrapper import wrap_array +from pylibraft.common.interruptible import cuda_interruptible + +from cuvs.common.exceptions import check_cuvs +from cuvs.neighbors.common import _check_input_array + +COVARIANCE_TYPES = { + "full": cuvsGMMCovarianceType.CUVS_GMM_COVARIANCE_FULL, + "tied": cuvsGMMCovarianceType.CUVS_GMM_COVARIANCE_TIED, + "diag": cuvsGMMCovarianceType.CUVS_GMM_COVARIANCE_DIAG, + "spherical": cuvsGMMCovarianceType.CUVS_GMM_COVARIANCE_SPHERICAL, +} + +COVARIANCE_NAMES = {v: k for k, v in COVARIANCE_TYPES.items()} + +INIT_METHOD_TYPES = { + "kmeans": cuvsGMMInitMethod.CUVS_GMM_INIT_KMEANS, + "k-means++": cuvsGMMInitMethod.CUVS_GMM_INIT_KMEANS_PLUS_PLUS, + "random": cuvsGMMInitMethod.CUVS_GMM_INIT_RANDOM, + "random_from_data": cuvsGMMInitMethod.CUVS_GMM_INIT_RANDOM_FROM_DATA, +} + +INIT_METHOD_NAMES 
= {v: k for k, v in INIT_METHOD_TYPES.items()} + + +def _covariance_shape(covariance_type, n_components, n_features): + """Logical shape of the covariance-typed buffers (sklearn conventions).""" + if covariance_type == "full": + return (n_components, n_features, n_features) + elif covariance_type == "tied": + return (n_features, n_features) + elif covariance_type == "diag": + return (n_components, n_features) + else: # spherical + return (n_components,) + + +cdef class GMMParams: + """ + Hyper-parameters for the Gaussian mixture EM solver + + Parameters + ---------- + n_components : int + The number of mixture components. + covariance_type : str + Covariance parameterization, one of "full", "tied", "diag", + "spherical". Matches scikit-learn's ``GaussianMixture``. + tol : float + Convergence threshold on the change of the per-sample average + log-likelihood (lower bound). + reg_covar : float + Non-negative regularization added to the diagonal of covariance. + max_iter : int + Maximum number of EM iterations for a single run. + n_init : int + Number of initializations to perform; the best result is kept. + init_method : str + Strategy used to initialize the responsibilities before EM. One of: + "kmeans" : run k-means and use the hard labels + "k-means++" : use the k-means++ seeding labels + "random" : random responsibilities, normalized per sample + "random_from_data" : pick n_components samples as one-hot + responsibilities + seed : int + Seed to the random number generator. 
+ """ + + cdef cuvsGMMParams* params + + def __cinit__(self): + check_cuvs(cuvsGMMParamsCreate(&self.params)) + + def __dealloc__(self): + if self.params is not NULL: + check_cuvs(cuvsGMMParamsDestroy(self.params)) + + def __init__(self, *, + n_components=None, + covariance_type=None, + tol=None, + reg_covar=None, + max_iter=None, + n_init=None, + init_method=None, + seed=None): + if n_components is not None: + self.params.n_components = n_components + if covariance_type is not None: + c_cov = COVARIANCE_TYPES[covariance_type] + self.params.covariance_type = c_cov + if tol is not None: + self.params.tol = tol + if reg_covar is not None: + self.params.reg_covar = reg_covar + if max_iter is not None: + self.params.max_iter = max_iter + if n_init is not None: + self.params.n_init = n_init + if init_method is not None: + c_init = INIT_METHOD_TYPES[init_method] + self.params.init = c_init + if seed is not None: + self.params.seed = seed + + @property + def n_components(self): + return self.params.n_components + + @property + def covariance_type(self): + return COVARIANCE_NAMES[self.params.covariance_type] + + @property + def tol(self): + return self.params.tol + + @property + def reg_covar(self): + return self.params.reg_covar + + @property + def max_iter(self): + return self.params.max_iter + + @property + def n_init(self): + return self.params.n_init + + @property + def init_method(self): + return INIT_METHOD_NAMES[self.params.init] + + @property + def seed(self): + return self.params.seed + + +FitOutput = namedtuple( + "FitOutput", + "weights means covariances precisions_chol precisions labels " + "lower_bound n_iter converged", +) + + +@auto_sync_resources +@auto_convert_output +def fit(GMMParams params, X, weights=None, means=None, covariances=None, + warm_start=False, resources=None): + """ + Fit a Gaussian mixture model with the EM algorithm + + Parameters + ---------- + params : GMMParams + Parameters of the EM solver. 
+ X : Input CUDA array interface compliant matrix shape (m, k) + weights : CUDA array interface compliant vector, optional + Writable, shape (n_components,). Holds the initial mixture weights + when ``warm_start`` is True and receives the fitted weights. + means : CUDA array interface compliant matrix, optional + Writable, shape (n_components, k). Holds the initial means when + ``warm_start`` is True and receives the fitted means. + covariances : CUDA array interface compliant array, optional + Writable; shape depends on ``params.covariance_type`` ("full": + (n_components, k, k), "tied": (k, k), "diag": + (n_components, k), "spherical": (n_components,)). Holds the + initial covariances when ``warm_start`` is True and receives + the fitted covariances. + warm_start : bool + Use the provided weights/means/covariances as the single + initialization instead of running ``params.n_init`` restarts. + {resources_docstring} + + Returns + ------- + weights : raft.device_ndarray + Fitted mixture weights, shape (n_components,) + means : raft.device_ndarray + Fitted component means, shape (n_components, k) + covariances : raft.device_ndarray + Fitted covariances (shape depends on covariance_type) + precisions_chol : raft.device_ndarray + Precision Cholesky factors (shape depends on covariance_type) + precisions : raft.device_ndarray + Precision matrices (shape depends on covariance_type) + labels : raft.device_ndarray + Hard component assignment per sample, shape (m,) + lower_bound : float + Per-sample average log-likelihood of the best fit + n_iter : int + Number of EM iterations of the best fit + converged : bool + Whether the best fit converged within ``params.tol`` + + Examples + -------- + + >>> import cupy as cp + >>> + >>> from cuvs.cluster.gmm import fit, GMMParams + >>> + >>> n_samples = 5000 + >>> n_features = 50 + >>> n_components = 3 + >>> + >>> X = cp.random.random_sample((n_samples, n_features), + ...
dtype=cp.float32) + + >>> params = GMMParams(n_components=n_components) + >>> weights, means, covariances, precisions_chol, *_ = fit(params, X) + """ + + x_ai = wrap_array(X) + _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')]) + + n_samples = x_ai.shape[0] + n_features = x_ai.shape[1] + n_components = params.n_components + cov_shape = _covariance_shape( + params.covariance_type, n_components, n_features) + + if weights is None: + if warm_start: + raise ValueError("warm_start requires initial weights") + weights = device_ndarray.empty((n_components,), dtype=x_ai.dtype) + if means is None: + if warm_start: + raise ValueError("warm_start requires initial means") + means = device_ndarray.empty( + (n_components, n_features), dtype=x_ai.dtype) + if covariances is None: + if warm_start: + raise ValueError("warm_start requires initial covariances") + covariances = device_ndarray.empty(cov_shape, dtype=x_ai.dtype) + + precisions_chol = device_ndarray.empty(cov_shape, dtype=x_ai.dtype) + precisions = device_ndarray.empty(cov_shape, dtype=x_ai.dtype) + labels = device_ndarray.empty((n_samples,), dtype='int32') + + weights_ai = wrap_array(weights) + means_ai = wrap_array(means) + covariances_ai = wrap_array(covariances) + precisions_chol_ai = wrap_array(precisions_chol) + precisions_ai = wrap_array(precisions) + labels_ai = wrap_array(labels) + + _check_input_array(weights_ai, [x_ai.dtype], exp_rows=n_components) + _check_input_array(means_ai, [x_ai.dtype], exp_rows=n_components, + exp_cols=n_features) + _check_input_array(covariances_ai, [x_ai.dtype]) + _check_input_array(labels_ai, [np.dtype('int32')], exp_rows=n_samples) + + cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai) + cdef cydlpack.DLManagedTensor* weights_dlpack = \ + cydlpack.dlpack_c(weights_ai) + cdef cydlpack.DLManagedTensor* means_dlpack = cydlpack.dlpack_c(means_ai) + cdef cydlpack.DLManagedTensor* covariances_dlpack = \ + cydlpack.dlpack_c(covariances_ai) + cdef 
cydlpack.DLManagedTensor* precisions_chol_dlpack = \
+        cydlpack.dlpack_c(precisions_chol_ai)
+    cdef cydlpack.DLManagedTensor* precisions_dlpack = \
+        cydlpack.dlpack_c(precisions_ai)
+    cdef cydlpack.DLManagedTensor* labels_dlpack = cydlpack.dlpack_c(labels_ai)
+
+    cdef cuvsResources_t res = resources.get_c_obj()
+
+    cdef double lower_bound = 0
+    cdef int n_iter = 0
+    cdef bool converged = False
+    cdef bool c_warm_start = warm_start
+
+    with cuda_interruptible():
+        check_cuvs(cuvsGMMFit(
+            res,
+            params.params,
+            x_dlpack,
+            weights_dlpack,
+            means_dlpack,
+            covariances_dlpack,
+            precisions_chol_dlpack,
+            precisions_dlpack,
+            labels_dlpack,
+            &lower_bound,
+            &n_iter,
+            &converged,
+            c_warm_start))
+
+    return FitOutput(weights, means, covariances, precisions_chol,
+                     precisions, labels, lower_bound, n_iter,
+                     True if converged else False)
+
+
+@auto_sync_resources
+@auto_convert_output
+def predict(GMMParams params, X, weights, means, precisions_chol,
+            labels=None, resources=None):
+    """
+    Hard component labels (argmax responsibility) for new data
+
+    Parameters
+    ----------
+    params : GMMParams
+        Parameters used to fit the GMM model.
+    X : Input CUDA array interface compliant matrix shape (m, k),
+        float32 or float64
+    weights : Fitted mixture weights, shape (n_components,), same dtype
+        as X. Its length must equal ``params.n_components``.
+    means : Fitted component means, shape (n_components, k), same dtype
+        as X
+    precisions_chol : Fitted precision Cholesky factors (shape depends on
+        covariance_type), same dtype as X
+    labels : Optional preallocated CUDA array interface vector shape (m,)
+        to hold the output (int32). Allocated here when omitted.
+    {resources_docstring}
+
+    Returns
+    -------
+    labels : raft.device_ndarray
+        Component assignment for each datapoint in X
+
+    Examples
+    --------
+
+    >>> import cupy as cp
+    >>>
+    >>> from cuvs.cluster.gmm import fit, predict, GMMParams
+    >>>
+    >>> X = cp.random.random_sample((5000, 50), dtype=cp.float32)
+    >>> params = GMMParams(n_components=3)
+    >>> weights, means, covariances, precisions_chol, *_ = fit(params, X)
+    >>>
+    >>> labels = predict(params, X, weights, means, precisions_chol)
+    """
+
+    x_ai = wrap_array(X)
+    _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')])
+
+    if labels is None:
+        labels = device_ndarray.empty((x_ai.shape[0],), dtype='int32')
+
+    labels_ai = wrap_array(labels)
+    _check_input_array(
+        labels_ai, [np.dtype('int32')], exp_rows=x_ai.shape[0])
+
+    weights_ai = wrap_array(weights)
+    means_ai = wrap_array(means)
+    precisions_chol_ai = wrap_array(precisions_chol)
+    # One weight per component: the same length check fit() applies, so a
+    # mis-sized vector is rejected before it reaches the C call.
+    _check_input_array(weights_ai, [x_ai.dtype],
+                       exp_rows=params.n_components)
+    _check_input_array(means_ai, [x_ai.dtype], exp_rows=params.n_components,
+                       exp_cols=x_ai.shape[1])
+    _check_input_array(precisions_chol_ai, [x_ai.dtype])
+
+    cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai)
+    cdef cydlpack.DLManagedTensor* weights_dlpack = \
+        cydlpack.dlpack_c(weights_ai)
+    cdef cydlpack.DLManagedTensor* means_dlpack = cydlpack.dlpack_c(means_ai)
+    cdef cydlpack.DLManagedTensor* precisions_chol_dlpack = \
+        cydlpack.dlpack_c(precisions_chol_ai)
+    cdef cydlpack.DLManagedTensor* labels_dlpack = cydlpack.dlpack_c(labels_ai)
+
+    cdef cuvsResources_t res = resources.get_c_obj()
+
+    with cuda_interruptible():
+        check_cuvs(cuvsGMMPredict(
+            res,
+            params.params,
+            x_dlpack,
+            weights_dlpack,
+            means_dlpack,
+            precisions_chol_dlpack,
+            labels_dlpack))
+
+    return labels
+
+
+@auto_sync_resources
+@auto_convert_output
+def predict_proba(GMMParams params, X, weights, means, precisions_chol,
+                  resp=None, resources=None):
+    """
+    Posterior responsibilities for new data
+
+    Parameters
+    ----------
+    params : GMMParams
+        Parameters used to fit the GMM model.
+    X : Input CUDA array interface compliant matrix shape (m, k),
+        float32 or float64
+    weights : Fitted mixture weights, shape (n_components,), same dtype
+        as X. Its length must equal ``params.n_components``.
+    means : Fitted component means, shape (n_components, k), same dtype
+        as X
+    precisions_chol : Fitted precision Cholesky factors (shape depends on
+        covariance_type), same dtype as X
+    resp : Optional preallocated CUDA array interface matrix shape
+        (m, n_components) to hold the output, same dtype as X.
+        Allocated here when omitted.
+    {resources_docstring}
+
+    Returns
+    -------
+    resp : raft.device_ndarray
+        Posterior probability of each component for each sample
+
+    Examples
+    --------
+
+    >>> import cupy as cp
+    >>>
+    >>> from cuvs.cluster.gmm import fit, predict_proba, GMMParams
+    >>>
+    >>> X = cp.random.random_sample((5000, 50), dtype=cp.float32)
+    >>> params = GMMParams(n_components=3)
+    >>> weights, means, covariances, precisions_chol, *_ = fit(params, X)
+    >>>
+    >>> resp = predict_proba(params, X, weights, means, precisions_chol)
+    """
+
+    x_ai = wrap_array(X)
+    _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')])
+
+    if resp is None:
+        resp = device_ndarray.empty(
+            (x_ai.shape[0], params.n_components), dtype=x_ai.dtype)
+
+    resp_ai = wrap_array(resp)
+    _check_input_array(resp_ai, [x_ai.dtype], exp_rows=x_ai.shape[0],
+                       exp_cols=params.n_components)
+
+    weights_ai = wrap_array(weights)
+    means_ai = wrap_array(means)
+    precisions_chol_ai = wrap_array(precisions_chol)
+    # One weight per component: the same length check fit() applies, so a
+    # mis-sized vector is rejected before it reaches the C call.
+    _check_input_array(weights_ai, [x_ai.dtype],
+                       exp_rows=params.n_components)
+    _check_input_array(means_ai, [x_ai.dtype], exp_rows=params.n_components,
+                       exp_cols=x_ai.shape[1])
+    _check_input_array(precisions_chol_ai, [x_ai.dtype])
+
+    cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai)
+    cdef cydlpack.DLManagedTensor* weights_dlpack = \
+        cydlpack.dlpack_c(weights_ai)
+    cdef cydlpack.DLManagedTensor* means_dlpack = cydlpack.dlpack_c(means_ai)
+    cdef cydlpack.DLManagedTensor* precisions_chol_dlpack = \
+        cydlpack.dlpack_c(precisions_chol_ai)
+    cdef cydlpack.DLManagedTensor* resp_dlpack = cydlpack.dlpack_c(resp_ai)
+
+    cdef cuvsResources_t res = resources.get_c_obj()
+
+    with cuda_interruptible():
+        check_cuvs(cuvsGMMPredictProba(
+            res,
+            params.params,
+            x_dlpack,
+            weights_dlpack,
+            means_dlpack,
+            precisions_chol_dlpack,
+            resp_dlpack))
+
+    return resp
+
+
+@auto_sync_resources
+@auto_convert_output
+def score_samples(GMMParams params, X, weights, means, precisions_chol,
+                  log_prob=None, resources=None):
+    """
+    Per-sample log-likelihood log p(x_i) for new data
+
+    Parameters
+    ----------
+    params : GMMParams
+        Parameters used to fit the GMM model.
+    X : Input CUDA array interface compliant matrix shape (m, k),
+        float32 or float64
+    weights : Fitted mixture weights, shape (n_components,), same dtype
+        as X. Its length must equal ``params.n_components``.
+    means : Fitted component means, shape (n_components, k), same dtype
+        as X
+    precisions_chol : Fitted precision Cholesky factors (shape depends on
+        covariance_type), same dtype as X
+    log_prob : Optional preallocated CUDA array interface vector shape (m,)
+        to hold the output, same dtype as X. Allocated here when omitted.
+    {resources_docstring}
+
+    Returns
+    -------
+    log_prob : raft.device_ndarray
+        Log-likelihood of each sample under the model
+
+    Examples
+    --------
+
+    >>> import cupy as cp
+    >>>
+    >>> from cuvs.cluster.gmm import fit, score_samples, GMMParams
+    >>>
+    >>> X = cp.random.random_sample((5000, 50), dtype=cp.float32)
+    >>> params = GMMParams(n_components=3)
+    >>> weights, means, covariances, precisions_chol, *_ = fit(params, X)
+    >>>
+    >>> log_prob = score_samples(params, X, weights, means, precisions_chol)
+    """
+
+    x_ai = wrap_array(X)
+    _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')])
+
+    if log_prob is None:
+        log_prob = device_ndarray.empty((x_ai.shape[0],), dtype=x_ai.dtype)
+
+    log_prob_ai = wrap_array(log_prob)
+    _check_input_array(log_prob_ai, [x_ai.dtype], exp_rows=x_ai.shape[0])
+
+    weights_ai = wrap_array(weights)
+    means_ai = wrap_array(means)
+    precisions_chol_ai = wrap_array(precisions_chol)
+    # One weight per component: the same length check fit() applies, so a
+    # mis-sized vector is rejected before it reaches the C call.
+    _check_input_array(weights_ai, [x_ai.dtype],
+                       exp_rows=params.n_components)
+    _check_input_array(means_ai, [x_ai.dtype], exp_rows=params.n_components,
+                       exp_cols=x_ai.shape[1])
+    _check_input_array(precisions_chol_ai, [x_ai.dtype])
+
+    cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai)
+    cdef cydlpack.DLManagedTensor* weights_dlpack = \
+        cydlpack.dlpack_c(weights_ai)
+    cdef cydlpack.DLManagedTensor* means_dlpack = cydlpack.dlpack_c(means_ai)
+    cdef cydlpack.DLManagedTensor* precisions_chol_dlpack = \
+        cydlpack.dlpack_c(precisions_chol_ai)
+    cdef cydlpack.DLManagedTensor* log_prob_dlpack = \
+        cydlpack.dlpack_c(log_prob_ai)
+
+    cdef cuvsResources_t res = resources.get_c_obj()
+
+    with cuda_interruptible():
+        check_cuvs(cuvsGMMScoreSamples(
+            res,
+            params.params,
+            x_dlpack,
+            weights_dlpack,
+            means_dlpack,
+            precisions_chol_dlpack,
+            log_prob_dlpack))
+
+    return log_prob
diff --git a/python/cuvs/cuvs/tests/test_gmm.py b/python/cuvs/cuvs/tests/test_gmm.py
new file mode 100644
index 0000000000..a70fb11ef8
--- /dev/null
+++ b/python/cuvs/cuvs/tests/test_gmm.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import numpy as np
+import pytest
+from pylibraft.common import device_ndarray
+from sklearn.datasets import make_blobs
+from sklearn.metrics import adjusted_rand_score
+from sklearn.mixture import GaussianMixture as skGaussianMixture
+
+from cuvs.cluster.gmm import (
+    GMMParams,
+    fit,
+    predict,
+    predict_proba,
+    score_samples,
+)
+
+COVARIANCE_TYPES = ["full", "tied", "diag", "spherical"]
+INIT_METHODS = ["kmeans", "k-means++", "random", "random_from_data"]
+DTYPES = [np.float32, np.float64]
+
+
+def _blobs(n_samples=1000, n_features=8, centers=4, seed=0, dtype=np.float32):
+    """Return (X, y): contiguous blob samples cast to dtype, and blob ids."""
+    samples, blob_ids = make_blobs(
+        n_samples=n_samples,
+        n_features=n_features,
+        centers=centers,
+        cluster_std=1.0,
+        random_state=seed,
+    )
+    samples = np.ascontiguousarray(samples.astype(dtype))
+    return samples, blob_ids
+
+
+def _rel(dtype):
+    """Relative tolerance for score comparisons; looser for float32."""
+    if dtype == np.float32:
+        return 1e-2
+    return 1e-5
+
+
+@pytest.mark.parametrize("covariance_type", COVARIANCE_TYPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_fit_matches_sklearn(covariance_type, dtype):
+    """Fit on blobs and compare labels and mean score against sklearn."""
+    X_host, truth = _blobs(dtype=dtype)
+    X_dev = device_ndarray(X_host)
+
+    cfg = GMMParams(
+        n_components=4,
+        covariance_type=covariance_type,
+        init_method="kmeans",
+        seed=0,
+    )
+    fitted = fit(cfg, X_dev)
+    # Positions 0, 1 and 3 of the fit output: weights, means, prec. Cholesky.
+    model = (fitted[0], fitted[1], fitted[3])
+
+    # Position 7 is the iteration count, position 8 the convergence flag.
+    assert fitted[7] >= 1
+    # Well-separated blobs converge within the default max_iter.
+    assert fitted[8]
+
+    labels = predict(cfg, X_dev, *model).copy_to_host()
+    log_prob = score_samples(cfg, X_dev, *model).copy_to_host()
+    score = float(np.asarray(log_prob).mean())
+
+    reference = skGaussianMixture(
+        n_components=4,
+        covariance_type=covariance_type,
+        init_params="kmeans",
+        random_state=0,
+    ).fit(X_host)
+
+    # Hard labels must recover the true blobs and agree with sklearn.
+    assert adjusted_rand_score(truth, labels) >= 0.95
+    assert adjusted_rand_score(reference.predict(X_host), labels) >= 0.95
+    # The mean per-sample log-likelihood must match sklearn's GMM.score.
+    assert score == pytest.approx(reference.score(X_host), rel=_rel(dtype))
+
+
+@pytest.mark.parametrize("init_method", INIT_METHODS)
+def test_init_methods_run(init_method):
+    """Every init strategy fits and yields labels within range(4)."""
+    X_host, truth = _blobs()
+    X_dev = device_ndarray(X_host)
+    cfg = GMMParams(n_components=4, init_method=init_method, seed=0)
+    fitted = fit(cfg, X_dev)
+    model = (fitted[0], fitted[1], fitted[3])
+    labels = predict(cfg, X_dev, *model).copy_to_host()
+    # Only the k-means-seeded inits are required to recover the clusters;
+    # the random inits merely have to run and return a valid labeling.
+    kmeans_seeded = init_method in ("kmeans", "k-means++")
+    if kmeans_seeded:
+        assert adjusted_rand_score(truth, labels) >= 0.95
+    assert set(np.unique(labels)) <= set(range(4))
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("covariance_type", COVARIANCE_TYPES)
+def test_predict_proba_normalized(covariance_type, dtype):
+    """Responsibilities are (m, 4), non-negative, and sum to 1 per row."""
+    X_host, _ = _blobs(dtype=dtype)
+    X_dev = device_ndarray(X_host)
+    cfg = GMMParams(n_components=4, covariance_type=covariance_type, seed=0)
+    fitted = fit(cfg, X_dev)
+    model = (fitted[0], fitted[1], fitted[3])
+    resp = predict_proba(cfg, X_dev, *model).copy_to_host()
+    n_rows = X_host.shape[0]
+    assert resp.shape == (n_rows, 4)
+    assert np.all(resp >= -1e-5)
+    row_totals = resp.sum(axis=1)
+    np.testing.assert_allclose(row_totals, 1.0, atol=1e-3)
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("covariance_type", COVARIANCE_TYPES)
+def test_predict_matches_predict_proba_argmax(covariance_type, dtype):
+    """predict() equals the row-wise argmax of predict_proba()."""
+    X_host, _ = _blobs(dtype=dtype)
+    X_dev = device_ndarray(X_host)
+    cfg = GMMParams(n_components=4, covariance_type=covariance_type, seed=0)
+    fitted = fit(cfg, X_dev)
+    model = (fitted[0], fitted[1], fitted[3])
+    labels = predict(cfg, X_dev, *model).copy_to_host()
+    resp = predict_proba(cfg, X_dev, *model).copy_to_host()
+    np.testing.assert_array_equal(labels, resp.argmax(axis=1))
+
+
+def test_score_samples_matches_sklearn():
+    """Per-sample log-likelihoods match sklearn (float64, full cov.)."""
+    X_host, _ = _blobs(dtype=np.float64)
+    X_dev = device_ndarray(X_host)
+    cfg = GMMParams(
+        n_components=4, covariance_type="full", init_method="kmeans", seed=0
+    )
+    fitted = fit(cfg, X_dev)
+    model = (fitted[0], fitted[1], fitted[3])
+    logp = score_samples(cfg, X_dev, *model).copy_to_host()
+
+    reference = skGaussianMixture(
+        n_components=4,
+        covariance_type="full",
+        init_params="kmeans",
+        random_state=0,
+    ).fit(X_host)
+    expected = reference.score_samples(X_host)
+    np.testing.assert_allclose(
+        np.asarray(logp), expected, rtol=1e-4, atol=1e-4
+    )
+
+
+def test_warm_start():
+    """A warm-started re-fit runs and does not lower the lower bound."""
+    X_host, _ = _blobs()
+    X_dev = device_ndarray(X_host)
+    cfg = GMMParams(
+        n_components=4,
+        covariance_type="full",
+        init_method="kmeans",
+        max_iter=1,
+        seed=0,
+    )
+    cold = fit(cfg, X_dev)
+    # Re-fit starting from the parameters of the first fit. It has to run,
+    # report a finite lower bound (position 6 of the output), and not regress
+    # the objective below the first fit.
+    warm = fit(
+        cfg,
+        X_dev,
+        weights=cold[0],
+        means=cold[1],
+        covariances=cold[2],
+        warm_start=True,
+    )
+    assert np.isfinite(warm[6])
+    assert warm[6] >= cold[6] - 1e-6