diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 8e46764ce4..dcec7ade8d 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -11,6 +11,11 @@ #include #include +#include +#include +#include +#include + #include "../core/exceptions.hpp" #include "../core/interop.hpp" @@ -213,10 +218,14 @@ void _cluster_cost(cuvsResources_t res, if (cuvs::core::is_dlpack_device_compatible(X)) { using mdspan_type = raft::device_matrix_view; + auto d_cost = raft::make_device_scalar(*res_ptr, T{0}); cuvs::cluster::kmeans::cluster_cost(*res_ptr, cuvs::core::from_dlpack(X_tensor), cuvs::core::from_dlpack(centroids_tensor), - raft::make_host_scalar_view(&cost_temp)); + d_cost.view()); + raft::copy( + *res_ptr, raft::make_host_scalar_view(&cost_temp), raft::make_const_mdspan(d_cost.view())); + raft::resource::sync_stream(*res_ptr); } else { RAFT_FAIL("X dataset must be accessible on device memory"); } diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index e2b4ea4a36..9fab781372 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -1532,14 +1532,16 @@ void transform(raft::resources const& handle, * @param[out] cost Resulting cluster cost * @param[in] sample_weight Optional per-sample weights. * [len = n_samples] - * + * @param[in] X_norm Optional precomputed squared L2 row norms of X (||x||^2) [n_samples]. + * When provided, the internal norm computation is skipped. */ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight = std::nullopt); + raft::device_scalar_view cost, + std::optional> sample_weight = std::nullopt, + std::optional> X_norm = std::nullopt); /** * @brief Compute cluster cost @@ -1554,13 +1556,17 @@ void cluster_cost( * @param[out] cost Resulting cluster cost * @param[in] sample_weight Optional per-sample weights. * [len = n_samples] + * @param[in] X_norm Optional precomputed squared L2 row norms of X (||x||^2, + * i.e. sum of squares without the sqrt) [n_samples]. When + * provided, the internal norm computation is skipped. */ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight = std::nullopt); + raft::device_scalar_view cost, + std::optional> sample_weight = std::nullopt, + std::optional> X_norm = std::nullopt); /** * @brief Compute (optionally weighted) cluster cost @@ -1575,13 +1581,17 @@ void cluster_cost( * @param[out] cost Resulting cluster cost * @param[in] sample_weight Optional per-sample weights. * [len = n_samples] + * @param[in] X_norm Optional precomputed squared L2 row norms of X (||x||^2, + * i.e. sum of squares without the sqrt) [n_samples]. When + * provided, the internal norm computation is skipped. */ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight = std::nullopt); + raft::device_scalar_view cost, + std::optional> sample_weight = std::nullopt, + std::optional> X_norm = std::nullopt); /** * @brief Compute (optionally weighted) cluster cost @@ -1596,13 +1606,17 @@ void cluster_cost( * @param[out] cost Resulting cluster cost * @param[in] sample_weight Optional per-sample weights. * [len = n_samples] + * @param[in] X_norm Optional precomputed squared L2 row norms of X (||x||^2, + * i.e. sum of squares without the sqrt) [n_samples]. When + * provided, the internal norm computation is skipped. */ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight = std::nullopt); + raft::device_scalar_view cost, + std::optional> sample_weight = std::nullopt, + std::optional> X_norm = std::nullopt); /** * @} */ diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 635e8813bd..4235bc0198 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -931,7 +931,11 @@ void kmeans_fit( auto centroids_const = raft::make_device_matrix_view( cur_centroids_ptr, n_clusters, n_features); - iter_inertia = DataT{0}; + auto d_iter_inertia = raft::make_device_scalar(handle, DataT{0}); + auto d_batch_cost = raft::make_device_scalar(handle, DataT{0}); + DataT* p_acc = d_iter_inertia.data_handle(); + DataT* p_batch = d_batch_cost.data_handle(); + data_batches.reset(); using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn; std::optional wt_it; @@ -956,15 +960,33 @@ void kmeans_fit( cur_batch_weights(static_cast(data_batch.offset()), wt_data, cur_batch_size); } - DataT batch_cost = DataT{0}; - cuvs::cluster::kmeans::cluster_cost(handle, - batch_data_view, - centroids_const, - raft::make_host_scalar_view(&batch_cost), - batch_sw); + std::optional> batch_xnorm = std::nullopt; + if (need_compute_norms) { + if constexpr (data_on_device) { + batch_xnorm = raft::make_device_vector_view( + L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size); + } else if (norms_cached) { + raft::copy(L2NormBatch.data_handle(), + h_norm_cache.data_handle() + data_batch.offset(), + cur_batch_size, + stream); + batch_xnorm = raft::make_device_vector_view( + L2NormBatch.data_handle(), cur_batch_size); + } + } + + cuvs::cluster::kmeans::cluster_cost( + handle, batch_data_view, centroids_const, d_batch_cost.view(), batch_sw, batch_xnorm); - iter_inertia += batch_cost; + raft::linalg::map_offset(handle, + raft::make_device_vector_view(p_acc, 1), + [p_acc, p_batch] __device__(int) { return *p_acc + *p_batch; }); } + + raft::copy(handle, + raft::make_host_scalar_view(&iter_inertia), + raft::make_const_mdspan(d_iter_inertia.view())); + raft::resource::sync_stream(handle); } if (iter_inertia < inertia[0]) { diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index a290f7372f..6a80aa51c9 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1139,8 +1139,12 @@ void build_hierarchical(const raft::resources& handle, reinterpret_cast(dataset), n_rows, dim); auto centroids_view = raft::make_device_matrix_view(cluster_centers, n_clusters, dim); - cuvs::cluster::kmeans::cluster_cost( - handle, X_view, centroids_view, raft::make_host_scalar_view(inertia)); + auto d_inertia = raft::make_device_scalar(handle, MathT{0}); + cuvs::cluster::kmeans::cluster_cost(handle, X_view, centroids_view, d_inertia.view()); + raft::copy(handle, + raft::make_host_scalar_view(inertia), + raft::make_const_mdspan(d_inertia.view())); + raft::resource::sync_stream(handle, stream); } else { RAFT_LOG_WARN("Inertia is not computed for non float/double types"); } diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 06da1fc1de..6aa0e200bd 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -382,6 +382,8 @@ void min_cluster_distance(raft::resources const& handle, * @param[in] centroids Cluster centroids [n_clusters x n_features] * @param[out] cost Sum of squared distances to nearest centroid (device) * @param[in] sample_weight Optional per-sample weights [n_samples] + * @param[in] X_norm Optional precomputed squared L2 row norms of X (||x||^2) [n_samples]. + * When provided, the internal norm computation is skipped. */ template void cluster_cost( @@ -389,7 +391,8 @@ void cluster_cost( raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_scalar_view cost, - std::optional> sample_weight = std::nullopt) + std::optional> sample_weight = std::nullopt, + std::optional> X_norm = std::nullopt) { auto stream = raft::resource::get_cuda_stream(handle); auto n_clusters = centroids.extent(0); @@ -398,8 +401,18 @@ void cluster_cost( rmm::device_uvector workspace(n_samples * sizeof(IndexT), stream); - auto x_norms = raft::make_device_vector(handle, n_samples); - raft::linalg::norm(handle, X, x_norms.view()); + std::optional> x_norms_buf; + DataT* x_norms_ptr; + if (X_norm.has_value()) { + RAFT_EXPECTS(X_norm->extent(0) == n_samples, "X_norm size !=n_samples"); + x_norms_ptr = const_cast(X_norm->data_handle()); + } else { + x_norms_buf.emplace(raft::make_device_vector(handle, n_samples)); + raft::linalg::norm( + handle, X, x_norms_buf->view()); + x_norms_ptr = x_norms_buf->data_handle(); + } + auto x_norms_view = raft::make_device_vector_view(x_norms_ptr, n_samples); auto min_cluster_distance = raft::make_device_vector(handle, n_samples); rmm::device_uvector l2_norm_or_distance_buffer(0, stream); @@ -412,7 +425,7 @@ void cluster_cost( raft::make_device_matrix_view( const_cast(centroids.data_handle()), n_clusters, n_features), min_cluster_distance.view(), - x_norms.view(), + x_norms_view, l2_norm_or_distance_buffer, metric, n_samples, @@ -431,34 +444,6 @@ void cluster_cost( handle, min_cluster_distance.view(), workspace, cost, raft::add_op{}); } -/** - * @brief Compute (optionally weighted) cluster cost (inertia) — host-scalar output. - * - * Convenience wrapper that copies the result to host and synchronizes. - * - * @tparam DataT float or double - * @tparam IndexT Index type - * - * @param[in] handle The raft handle - * @param[in] X Input data [n_samples x n_features] - * @param[in] centroids Cluster centroids [n_clusters x n_features] - * @param[out] cost Sum of squared distances to nearest centroid (host) - * @param[in] sample_weight Optional per-sample weights [n_samples] - */ -template -void cluster_cost( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight = std::nullopt) -{ - auto device_cost = raft::make_device_scalar(handle, DataT(0)); - cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, device_cost.view(), sample_weight); - raft::copy(handle, cost, raft::make_const_mdspan(device_cost.view())); - raft::resource::sync_stream(handle); -} - /** * @brief Calculates a pair for every sample in input 'X' where key is an * index of one of the 'centroids' (index of the nearest centroid) and 'value' diff --git a/cpp/src/cluster/kmeans_cluster_cost.cu b/cpp/src/cluster/kmeans_cluster_cost.cu index 0cdc182fb9..713e74d101 100644 --- a/cpp/src/cluster/kmeans_cluster_cost.cu +++ b/cpp/src/cluster/kmeans_cluster_cost.cu @@ -11,36 +11,44 @@ namespace cuvs::cluster::kmeans { void cluster_cost(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight) + raft::device_scalar_view cost, + std::optional> sample_weight, + std::optional> X_norm) { - cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); + cuvs::cluster::kmeans::cluster_cost( + handle, X, centroids, cost, sample_weight, X_norm); } void cluster_cost(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight) + raft::device_scalar_view cost, + std::optional> sample_weight, + std::optional> X_norm) { - cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); + cuvs::cluster::kmeans::cluster_cost( + handle, X, centroids, cost, sample_weight, X_norm); } void cluster_cost(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight) + raft::device_scalar_view cost, + std::optional> sample_weight, + std::optional> X_norm) { - cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); + cuvs::cluster::kmeans::cluster_cost( + handle, X, centroids, cost, sample_weight, X_norm); } void cluster_cost(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, - std::optional> sample_weight) + raft::device_scalar_view cost, + std::optional> sample_weight, + std::optional> X_norm) { - cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, cost, sample_weight); + cuvs::cluster::kmeans::cluster_cost( + handle, X, centroids, cost, sample_weight, X_norm); } } // namespace cuvs::cluster::kmeans diff --git a/fern/pages/cpp_api/cpp-api-cluster-kmeans.md b/fern/pages/cpp_api/cpp-api-cluster-kmeans.md index cbb5a73de7..57c44d41b8 100644 --- a/fern/pages/cpp_api/cpp-api-cluster-kmeans.md +++ b/fern/pages/cpp_api/cpp-api-cluster-kmeans.md @@ -967,8 +967,9 @@ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, -raft::host_scalar_view cost, -std::optional> sample_weight = std::nullopt); +raft::device_scalar_view cost, +std::optional> sample_weight = std::nullopt, +std::optional> X_norm = std::nullopt); ``` **Parameters** @@ -978,8 +979,9 @@ std::optional> sample_weight = std::n | `handle` | in | `const raft::resources&` | The raft handle | | `X` | in | `raft::device_matrix_view` | Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features] | | `centroids` | in | `raft::device_matrix_view` | Cluster centroids. The data must be in row-major format. [dim = n_clusters x n_features] | -| `cost` | out | `raft::host_scalar_view` | Resulting cluster cost | +| `cost` | out | `raft::device_scalar_view` | Resulting cluster cost | | `sample_weight` | in | `std::optional>` | Optional per-sample weights. [len = n_samples]
Default: `std::nullopt`. | +| `X_norm` | in | `std::optional>` | Optional precomputed L2 norms of X rows [n_samples]. When provided, the internal norm computation is skipped.
Default: `std::nullopt`. | **Returns** @@ -994,8 +996,9 @@ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, -raft::host_scalar_view cost, -std::optional> sample_weight = std::nullopt); +raft::device_scalar_view cost, +std::optional> sample_weight = std::nullopt, +std::optional> X_norm = std::nullopt); ``` **Parameters** @@ -1005,8 +1008,9 @@ std::optional> sample_weight = std:: | `handle` | in | `const raft::resources&` | The raft handle | | `X` | in | `raft::device_matrix_view` | Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features] | | `centroids` | in | `raft::device_matrix_view` | Cluster centroids. The data must be in row-major format. [dim = n_clusters x n_features] | -| `cost` | out | `raft::host_scalar_view` | Resulting cluster cost | +| `cost` | out | `raft::device_scalar_view` | Resulting cluster cost | | `sample_weight` | in | `std::optional>` | Optional per-sample weights. [len = n_samples]
Default: `std::nullopt`. | +| `X_norm` | in | `std::optional>` | Optional precomputed L2 norms of X rows [n_samples]. When provided, the internal norm computation is skipped.
Default: `std::nullopt`. | **Returns** @@ -1021,8 +1025,9 @@ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, -raft::host_scalar_view cost, -std::optional> sample_weight = std::nullopt); +raft::device_scalar_view cost, +std::optional> sample_weight = std::nullopt, +std::optional> X_norm = std::nullopt); ``` **Parameters** @@ -1032,8 +1037,9 @@ std::optional> sample_weight = st | `handle` | in | `const raft::resources&` | The raft handle | | `X` | in | `raft::device_matrix_view` | Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features] | | `centroids` | in | `raft::device_matrix_view` | Cluster centroids. The data must be in row-major format. [dim = n_clusters x n_features] | -| `cost` | out | `raft::host_scalar_view` | Resulting cluster cost | +| `cost` | out | `raft::device_scalar_view` | Resulting cluster cost | | `sample_weight` | in | `std::optional>` | Optional per-sample weights. [len = n_samples]
Default: `std::nullopt`. | +| `X_norm` | in | `std::optional>` | Optional precomputed L2 norms of X rows [n_samples]. When provided, the internal norm computation is skipped.
Default: `std::nullopt`. | **Returns** @@ -1048,8 +1054,9 @@ void cluster_cost( const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, -raft::host_scalar_view cost, -std::optional> sample_weight = std::nullopt); +raft::device_scalar_view cost, +std::optional> sample_weight = std::nullopt, +std::optional> X_norm = std::nullopt); ``` **Parameters** @@ -1059,8 +1066,9 @@ std::optional> sample_weight = s | `handle` | in | `const raft::resources&` | The raft handle | | `X` | in | `raft::device_matrix_view` | Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features] | | `centroids` | in | `raft::device_matrix_view` | Cluster centroids. The data must be in row-major format. [dim = n_clusters x n_features] | -| `cost` | out | `raft::host_scalar_view` | Resulting cluster cost | +| `cost` | out | `raft::device_scalar_view` | Resulting cluster cost | | `sample_weight` | in | `std::optional>` | Optional per-sample weights. [len = n_samples]
Default: `std::nullopt`. | +| `X_norm` | in | `std::optional>` | Optional precomputed L2 norms of X rows [n_samples]. When provided, the internal norm computation is skipped.
Default: `std::nullopt`. | **Returns** diff --git a/fern/pages/cpp_api/cpp-api-neighbors-common.md b/fern/pages/cpp_api/cpp-api-neighbors-common.md index 68430cf6fb..6a6b122a9b 100644 --- a/fern/pages/cpp_api/cpp-api-neighbors-common.md +++ b/fern/pages/cpp_api/cpp-api-neighbors-common.md @@ -147,7 +147,8 @@ Filtering for ANN Types enum class FilterType { None, Bitmap, - Bitset + Bitset, + UDF }; ``` @@ -158,6 +159,7 @@ enum class FilterType { | `None` | `` | | `Bitmap` | `` | | `Bitset` | `` | +| `UDF` | `` | ### neighbors::filtering::none_sample_filter::operator @@ -318,6 +320,33 @@ FilterType get_filter_type() const override; [`FilterType`](/api-reference/cpp-api-neighbors-common#neighbors-filtering-filtertype) + +### neighbors::filtering::udf_filter + +JIT-LTO user-defined filter predicate. + +The source must define a device function named by `function_name` with signature: + +Return `true` to allow a source vector to appear in the results and `false` to reject it. UDF dereferences it. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment. + +```cpp +struct udf_filter : public base_filter { + std::string source; + void* filter_data; + float filtering_rate; + std::string function_name; +}; +``` + +**Fields** + +| Name | Type | Description | +| --- | --- | --- | +| `source` | `std::string` | CUDA C++ source containing the device predicate. | +| `filter_data` | `void*` | Opaque device-accessible pointer passed to the predicate. | +| `filtering_rate` | `float` | Estimated fraction of rows rejected by the predicate, or negative if unknown. | +| `function_name` | `std::string` | Device function name to call from the generated CAGRA sample filter. | + ## ANN MG index build parameters