Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
195 commits
Select commit Hold shift + click to select a range
66d7fd3
combine impls
tarang-jain Apr 10, 2026
07707af
Multi-GPU Batched KMeans
viclafargue Apr 13, 2026
efc270f
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 13, 2026
0a09e6f
rm inertia_check
tarang-jain Apr 13, 2026
99a5730
change to warning
tarang-jain Apr 13, 2026
a077406
style
tarang-jain Apr 13, 2026
d659875
add init_size param
tarang-jain Apr 13, 2026
ec2e8b7
Merge branch 'main' into combine-batch
tarang-jain Apr 13, 2026
03a6473
docs
tarang-jain Apr 13, 2026
42a8d9d
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 13, 2026
86af2fa
rm direct cuda api calls
tarang-jain Apr 13, 2026
d4e4e2c
std::swap instead of raft::copy
tarang-jain Apr 14, 2026
0819af5
cache batch norms
tarang-jain Apr 14, 2026
e0f079c
centroid norms can also be cached per iteration
tarang-jain Apr 14, 2026
c2f7390
mg n_iter
tarang-jain Apr 14, 2026
b9c3102
pre-commit
tarang-jain Apr 14, 2026
e3956c1
do not break c abi
tarang-jain Apr 14, 2026
986d78a
Merge branch 'main' into combine-batch
tarang-jain Apr 14, 2026
7197b71
cluster_cost on device
viclafargue Apr 14, 2026
84ab315
Updated testing
viclafargue Apr 14, 2026
47d4b94
templating
viclafargue Apr 15, 2026
a8e1d26
Merge branch 'main' into combine-batch
tarang-jain Apr 16, 2026
384d054
fix checkWeight
tarang-jain Apr 21, 2026
455b286
merge upstream:
tarang-jain Apr 21, 2026
5462809
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 21, 2026
6ba759c
fix compilation
tarang-jain Apr 21, 2026
e76eaac
rel_tol
tarang-jain Apr 22, 2026
afbefdf
pass workspace
tarang-jain Apr 22, 2026
e62a63c
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 22, 2026
e4f08bf
style
tarang-jain Apr 22, 2026
6e4a8f0
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 22, 2026
4a8a85c
do not use batch scratch space; rm update_centroids
tarang-jain Apr 22, 2026
bbf2a9f
move the debug log
tarang-jain Apr 22, 2026
410092c
add new suffixed param struct
tarang-jain Apr 22, 2026
c515c1e
address pr reviews
tarang-jain Apr 22, 2026
e8e63ab
fix docstring
tarang-jain Apr 22, 2026
30c457c
fix wt_sum warning
tarang-jain Apr 22, 2026
ab96623
rm deprecationwarning and instead add FutureWarning:=
tarang-jain Apr 22, 2026
269f23c
unweighted to never materialize batch weights
tarang-jain Apr 22, 2026
80a22ca
add cpp tests
tarang-jain Apr 23, 2026
ac06b05
update cpp tests
tarang-jain Apr 23, 2026
855624a
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 23, 2026
0a6748d
refactor
viclafargue Apr 23, 2026
7055272
rename to mnmg_fit
viclafargue Apr 23, 2026
0569340
revert batch norms cache
tarang-jain Apr 23, 2026
8cac63a
increase zero cost threshold
tarang-jain Apr 24, 2026
f6df4ae
apply cuda event plus re-add h_norm_cache
tarang-jain Apr 24, 2026
9fc74b1
rm cosine expanded stuff
tarang-jain Apr 24, 2026
dec3dc4
resolve merge conflicts
tarang-jain Apr 28, 2026
0d030a2
change suffix of the params struct
tarang-jain Apr 28, 2026
b1c034e
replace 06 by 08, add todo and note
tarang-jain Apr 28, 2026
a482495
update to v2
tarang-jain Apr 28, 2026
8ecfdc1
avoid stream sync inside weight sum
tarang-jain Apr 29, 2026
1e1525e
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
ec22e07
empty
tarang-jain Apr 29, 2026
d2e410d
empty
tarang-jain Apr 29, 2026
b791c38
Merge branch 'main' into combine-batch
tarang-jain Apr 29, 2026
a05a006
new signatures with new struct
tarang-jain Apr 29, 2026
73293cf
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
880c7b9
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 30, 2026
e2035ec
revert change to calls in py and rust; add c tests
tarang-jain Apr 30, 2026
e28c200
Merge branch 'main' into combine-batch
tarang-jain May 1, 2026
55bbdad
use to_dlpack
tarang-jain May 5, 2026
9a9b8ee
cache device weights
tarang-jain May 5, 2026
a800b27
rm event
tarang-jain May 5, 2026
3db8582
update names
tarang-jain May 5, 2026
c048352
rename
tarang-jain May 5, 2026
2f968f8
rm docs
tarang-jain May 5, 2026
affe85a
empty
tarang-jain May 5, 2026
c6dea64
fix norm cache
tarang-jain May 5, 2026
7dfab3e
revert changes to minClusterDistanceCompute
tarang-jain May 6, 2026
7a383da
update tests to use mdspan instead of rmm
tarang-jain May 6, 2026
ce6c4b5
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
5a06a44
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
419619a
consolidate all unsigned commits
tarang-jain May 7, 2026
2d716ae
rm diff
tarang-jain May 7, 2026
066092b
allow batch sample weights
tarang-jain May 7, 2026
bbdd66d
Merge branch 'main' into mnmg-streaming
tarang-jain May 7, 2026
12d682c
single partition becomes special case
tarang-jain May 7, 2026
9e5e55c
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 7, 2026
28cda6a
Merge branch 'combine-batch' into mg-batched-kmeans
viclafargue May 7, 2026
bfb5290
Addressing review
viclafargue May 7, 2026
add9db1
optimize convergence check
viclafargue May 7, 2026
6c08a7b
Merge branch 'main' into mnmg-streaming
tarang-jain May 7, 2026
acbcd5a
Merge branch 'main' into mnmg-streaming
tarang-jain May 7, 2026
af606bc
Adressing review
viclafargue May 8, 2026
41c66b8
Merge branch 'main' into mg-batched-kmeans
viclafargue May 8, 2026
f664c2c
results on all ranks for RAFT + small optimization
viclafargue May 8, 2026
5430f42
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
b2ab5bd
merge origin
tarang-jain May 8, 2026
bbdf521
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 8, 2026
10e6def
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
2040145
reduce diff
tarang-jain May 8, 2026
1828462
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 8, 2026
5c5b8c8
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
05da5f3
rm prefetch
tarang-jain May 8, 2026
90435c1
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 8, 2026
db41338
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
6c2c03d
reviews
viclafargue May 11, 2026
7f6d664
Global sampling for init
viclafargue May 11, 2026
f8270e2
SNMG -> MNMG
viclafargue May 11, 2026
bbf0302
Merge branch 'main' into mg-batched-kmeans
viclafargue May 11, 2026
a14a6bc
adding asserts
viclafargue May 11, 2026
7b54a42
consume new init
tarang-jain May 11, 2026
d86b8b4
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 11, 2026
6e11f67
reduce diff
tarang-jain May 11, 2026
9f5b6e5
Merge branch 'main' into mnmg-streaming
tarang-jain May 13, 2026
aaef638
rm unnecessary functions
tarang-jain May 13, 2026
920a460
Merge branch 'main' into mnmg-streaming
tarang-jain May 13, 2026
548d7db
rm accessor templates for now
tarang-jain May 14, 2026
9f3a486
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 14, 2026
c93f248
Merge branch 'main' of https://github.com/rapidsai/cuvs into mnmg-str…
tarang-jain May 14, 2026
51fbf6c
merge upstream
tarang-jain May 20, 2026
d327569
cleanup; re-add device side overload
tarang-jain May 20, 2026
b5e66a3
re-instate removed docs
tarang-jain May 20, 2026
a636188
rm extra fit funcs
tarang-jain May 21, 2026
d3cafed
Merge branch 'release/26.06' into mnmg-streaming
tarang-jain May 26, 2026
1b547f4
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 26, 2026
72cfd43
cleanup
tarang-jain May 26, 2026
4d25e95
rm scaled_weights_cache
tarang-jain May 26, 2026
81155e6
rm unnecessary new types
tarang-jain May 26, 2026
85522aa
rm unused helper
tarang-jain May 26, 2026
00336b5
rm unnecessary stream sync
tarang-jain May 27, 2026
6585866
rm unnecessary lambda
tarang-jain May 27, 2026
aa6f28e
cleanup impl
tarang-jain May 27, 2026
178a7e7
rm unnecessary has_data guards
tarang-jain May 27, 2026
713bc7c
rm global_n host scalar
tarang-jain May 27, 2026
c576d8f
fixes
tarang-jain May 27, 2026
a401a0e
Merge branch 'release/26.06' into mnmg-streaming
tarang-jain May 27, 2026
8102596
fuse with in-memory impl
tarang-jain May 27, 2026
00d0adb
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 27, 2026
caefd53
style
tarang-jain May 27, 2026
8f6f83d
fix compilation
tarang-jain May 27, 2026
d88a991
Merge branch 'release/26.06' into mnmg-streaming
tarang-jain May 28, 2026
1b57b74
mg tests first commit
tarang-jain May 28, 2026
7bac418
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 28, 2026
9851017
merge upstream
tarang-jain May 28, 2026
f572877
update cmakelists
tarang-jain May 29, 2026
edaa7e7
merge upstream
tarang-jain May 29, 2026
588bb6a
rm batched tests
tarang-jain May 29, 2026
ad180ed
Merge branch 'main' into mnmg-streaming
tarang-jain May 29, 2026
72cc34b
rm unnecessary test stream sycns
tarang-jain May 29, 2026
ed50703
reset bs; assertion
tarang-jain May 29, 2026
28f6036
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 29, 2026
a811c56
rm has_data flag
tarang-jain May 29, 2026
95f334c
Merge branch 'main' into mnmg-streaming
tarang-jain May 30, 2026
d176314
fix export
tarang-jain May 30, 2026
1db9e02
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 30, 2026
089e970
avoid pinned scalar;get_nccl_comms before omp
tarang-jain May 31, 2026
6cc895c
use root from macro
tarang-jain Jun 1, 2026
4abe6f2
avoid copy and rank alloc with initarray
tarang-jain Jun 1, 2026
ebf188a
Merge branch 'main' of https://github.com/rapidsai/cuvs into mnmg-str…
tarang-jain Jun 1, 2026
785e4a3
fix compilation; guardrail MG CMake flag
tarang-jain Jun 1, 2026
9a526c8
get n_features from centroids
tarang-jain Jun 1, 2026
f08e581
add sigs to header
tarang-jain Jun 1, 2026
51efb42
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 5, 2026
3e3cac7
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 8, 2026
7ffae6d
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 9, 2026
447e136
put single mdspan in mg namespace
tarang-jain Jun 9, 2026
50d0359
Merge branch 'mnmg-streaming' of https://github.com/tarang-jain/cuvs …
tarang-jain Jun 9, 2026
b1e7521
single reduce
tarang-jain Jun 9, 2026
f8a3503
re-add distributed kmeanspp
tarang-jain Jun 9, 2026
272a9d5
revert differences from distributed init
tarang-jain Jun 9, 2026
1f9fd9e
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 9, 2026
ed705c6
optional allredice
tarang-jain Jun 9, 2026
c8135e5
Merge branch 'mnmg-streaming' of https://github.com/tarang-jain/cuvs …
tarang-jain Jun 9, 2026
cf5e831
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 9, 2026
8b6eab2
apply distributed init fixes
tarang-jain Jun 10, 2026
100f7d6
Merge branch 'mnmg-streaming' of https://github.com/tarang-jain/cuvs …
tarang-jain Jun 10, 2026
46d18d1
re-add host assertion for sum of weights
tarang-jain Jun 10, 2026
db23a72
remaining samples should be global
tarang-jain Jun 10, 2026
1e7c119
doc updates
tarang-jain Jun 10, 2026
116a6cf
add dask-like cpp test
tarang-jain Jun 10, 2026
ae284be
add warning to docs
tarang-jain Jun 11, 2026
d15cb28
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 11, 2026
653aac1
address pr reviews
tarang-jain Jun 11, 2026
92588ba
Merge branch 'mnmg-streaming' of https://github.com/tarang-jain/cuvs …
tarang-jain Jun 11, 2026
ef60e3c
add weight check for sg device weights
tarang-jain Jun 11, 2026
8831789
re-update the docs
tarang-jain Jun 11, 2026
3f800a1
update docs
tarang-jain Jun 11, 2026
32be863
undo fern changes
tarang-jain Jun 11, 2026
d0fa3a9
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 11, 2026
0b969c4
address reviews
tarang-jain Jun 12, 2026
040e82b
Merge branch 'mnmg-streaming' of https://github.com/tarang-jain/cuvs …
tarang-jain Jun 12, 2026
d65287d
fix bcast
tarang-jain Jun 12, 2026
3321926
style
tarang-jain Jun 12, 2026
de57f46
update docstring
tarang-jain Jun 12, 2026
1a6bbb5
add oversampling test
tarang-jain Jun 15, 2026
e627af4
address reviews
tarang-jain Jun 16, 2026
fd19b43
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 16, 2026
993b813
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 22, 2026
bd3e1de
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 23, 2026
40c77ca
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 24, 2026
079c216
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 24, 2026
7a38f62
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1304,10 +1304,10 @@ if(NOT BUILD_CPU_ONLY)
src/cluster/detail/minClusterDistanceCompute.cu
src/cluster/agglomerative.cu
src/cluster/kmeans_cluster_cost.cu
src/cluster/kmeans_fit_mg_float.cu
src/cluster/kmeans_fit_mg_double.cu
Comment thread
tarang-jain marked this conversation as resolved.
src/cluster/kmeans_fit_double.cu
src/cluster/kmeans_fit_float.cu
$<$<BOOL:${BUILD_MG_ALGOS}>:src/cluster/kmeans_fit_mg_float.cu>
$<$<BOOL:${BUILD_MG_ALGOS}>:src/cluster/kmeans_fit_mg_double.cu>
src/cluster/kmeans_auto_find_k_float.cu
src/cluster/kmeans_fit_predict_double.cu
src/cluster/kmeans_fit_predict_float.cu
Expand Down Expand Up @@ -1451,8 +1451,9 @@ if(NOT BUILD_CPU_ONLY)
)

target_compile_definitions(
cuvs_objs PRIVATE $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:CUVS_BUILD_CAGRA_HNSWLIB>
$<$<BOOL:${CUVS_NVTX}>:NVTX_ENABLED>
cuvs_objs
PRIVATE $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:CUVS_BUILD_CAGRA_HNSWLIB>
$<$<BOOL:${BUILD_MG_ALGOS}>:CUVS_BUILD_MG_ALGOS> $<$<BOOL:${CUVS_NVTX}>:NVTX_ENABLED>
)

target_link_libraries(
Expand Down Expand Up @@ -1490,8 +1491,6 @@ if(NOT BUILD_CPU_ONLY)
)
find_package(NCCL REQUIRED)
target_link_libraries(cuvs_objs PUBLIC $<BUILD_LOCAL_INTERFACE:NCCL::NCCL>)

target_compile_definitions(cuvs_objs PUBLIC CUVS_BUILD_MG_ALGOS)
endif()

set(CUVS_CUSOLVER_DEPENDENCY CUDA::cusolver${_ctk_static_suffix})
Expand Down Expand Up @@ -1526,8 +1525,9 @@ if(NOT BUILD_CPU_ONLY)
"$<$<AND:$<COMPILE_LANGUAGE:CUDA>,$<CONFIG:Debug>>:${CUVS_DEBUG_CUDA_FLAGS}>"
)
target_compile_definitions(
cuvs PUBLIC $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:CUVS_BUILD_CAGRA_HNSWLIB>
$<$<BOOL:${CUVS_NVTX}>:NVTX_ENABLED>
cuvs
PUBLIC $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:CUVS_BUILD_CAGRA_HNSWLIB>
$<$<BOOL:${BUILD_MG_ALGOS}>:CUVS_BUILD_MG_ALGOS> $<$<BOOL:${CUVS_NVTX}>:NVTX_ENABLED>
)

target_link_libraries(
Expand Down Expand Up @@ -1577,8 +1577,9 @@ SECTIONS

target_compile_options(cuvs_static PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${CUVS_CXX_FLAGS}>")
target_compile_definitions(
cuvs_static PUBLIC $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:CUVS_BUILD_CAGRA_HNSWLIB>
$<$<BOOL:${CUVS_NVTX}>:NVTX_ENABLED>
cuvs_static
PUBLIC $<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:CUVS_BUILD_CAGRA_HNSWLIB>
$<$<BOOL:${BUILD_MG_ALGOS}>:CUVS_BUILD_MG_ALGOS> $<$<BOOL:${CUVS_NVTX}>:NVTX_ENABLED>
)

target_include_directories(cuvs_static INTERFACE "$<INSTALL_INTERFACE:include>")
Expand Down
278 changes: 256 additions & 22 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <cuvs/core/export.hpp>
#include <optional>
#include <vector>

namespace CUVS_EXPORT cuvs {
namespace cluster {
Expand Down Expand Up @@ -94,7 +95,15 @@ struct params : base_params {
int n_init = 1;

/**
* Oversampling factor for use in the k-means|| algorithm
* Oversampling factor for use in the k-means|| algorithm.
*
* In the single-GPU path the value `0` is overloaded as an algorithm switch
* that selects the classic sequential k-means++ instead of the scalable
* variant. Any value `> 0` is used as-is.
*
* In the multi-GPU path (@ref cuvs::cluster::kmeans::mg::fit with device-
* resident inputs) any value `< 1.0` (including `0`) is internally clamped to `1.0`.
* Values `>= 1.0` are passed through unchanged.
*/
double oversampling_factor = 2.0;

Expand Down Expand Up @@ -137,11 +146,13 @@ struct params : base_params {
/**
* Number of samples to process per GPU batch when fitting with host data.
* When set to 0, defaults to n_samples (process all at once).
* Only used by the batched (host-data) code path and ignored by device-data
* overloads.
* Only used by the batched (host-data) code path and ignored by
* device-data overloads.
*
* In multi-GPU mode, this is a per-rank batch size. Each rank processes up to
* this many local samples per batch, clamped to that rank's local sample count.
* In multi-GPU mode this is a per-rank batch size: each rank processes up
* to this many local samples per batch, clamped to that rank's local sample
* count. When MG is invoked with device partitions, the runtime ignores
* `streaming_batch_size` and processes each partition in full.
* Default: 0 (process all data at once).
*/
int64_t streaming_batch_size = 0;
Expand Down Expand Up @@ -179,26 +190,22 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };

/**
* @brief Find clusters with k-means algorithm using batched processing of host data.
* Single-GPU only.
*
* Multi-GPU migration (breaking change in cuVS 26.08): earlier releases
* silently dispatched this single-GPU overload to a multi-GPU implementation
* when the supplied RAFT handle had RAFT comms or an SNMG clique attached. That
* implicit dispatch has been removed: this overload is now strictly
* single-GPU. If `handle` carries communications/clique state it is ignored and the call falls back
* to the single-GPU path. To run on multiple GPUs, call `cuvs::cluster::kmeans::mg::fit`
* explicitly.
*
* TODO: Evaluate replacing the extent type with int64_t. Reference issue:
* https://github.com/rapidsai/cuvs/issues/1961
*
* This overload supports out-of-core computation where the dataset resides
* on the host. Data is processed in GPU-sized batches, streaming from host to device.
* The batch size is controlled by params.streaming_batch_size. In multi-GPU mode,
* this is a per-rank batch size.
*
* Multi-GPU dispatch is selected automatically based on the handle state:
* - If `raft::resource::is_multi_gpu(handle)` (cuVS SNMG): the full dataset X
* is split across GPUs internally with an OpenMP parallel region and NCCL.
* - If `raft::resource::comms_initialized(handle)` (Dask/Ray/MPI): X is treated as
* this worker's partition, and RAFT communicators are used for collectives.
* - Otherwise: single-GPU batched k-means.
*
* With `params.init == InitMethod::KMeansPlusPlus` in multi-GPU mode, the
* effective initialization sample must fit in GPU memory on every rank because
* it is materialized on every device. Rank 0 must also have enough GPU memory
* for the seeding workspace before centroids are broadcast.
* on the host. Data is processed in batches, streaming from host to
* device. The batch size is controlled by `params.streaming_batch_size`.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
Expand Down Expand Up @@ -229,8 +236,7 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* raft::make_host_scalar_view(&n_iter));
* @endcode
*
* @param[in] handle The raft handle. When a multi-GPU resource is
* attached, multi-GPU dispatch is used automatically.
* @param[in] handle The raft handle.
* @param[in] params Parameters for KMeans model. Batch size is read from
* params.streaming_batch_size.
* @param[in] X Training instances on HOST memory. The data must
Expand Down Expand Up @@ -1607,6 +1613,234 @@ void cluster_cost(
* @}
*/

#ifdef CUVS_BUILD_MG_ALGOS
namespace mg {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One compatibility question: the public header now declares cuvs::cluster::kmeans::mg::fit unconditionally, while cpp/CMakeLists.txt only compiles kmeans_fit_mg_float.cu and kmeans_fit_mg_double.cu under BUILD_MG_ALGOS. If a non-MG build installs this header, callers can compile against mg::fit but fail at link time. Would it be better to guard those declarations or provide stubs with a clear error?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough! I'll add the guarding macro in the header itself.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also thinking about removing the ::mg namespace and rather naming the API cuvs::cluster::kmeans_multi_gpu or renaming that namespace to multi_gpu. That makes it more explicit. I think abbreviations like MG, SNMG and MNMG are more internal, rather than user facing.

/**
* @defgroup kmeans_mg Multi-GPU / out-of-core k-means fit
* @{
*
* @brief Explicit multi-GPU k-means entry points.
*
* All multi-GPU k-means APIs live in this namespace
* (`cuvs::cluster::kmeans::mg`). To run k-means on multiple GPUs a `handle` that
* carries either an SNMG clique (`raft::resource::is_multi_gpu(handle)`) or
* initialized RAFT comms (`raft::resource::comms_initialized(handle)`) must be used.
*
* Migration from earlier releases (breaking change in cuVS 26.08): before
* this release, the single-GPU `cuvs::cluster::kmeans::fit` overloads would
* silently dispatch to the multi-GPU backend when the supplied `handle`
* carried RAFT comms or an SNMG clique. That implicit dispatch has been
* removed: the single-GPU `fit` is now strictly single-GPU. Existing
* multi-GPU call sites must be updated to invoke
* `cuvs::cluster::kmeans::mg::fit` directly. Two flavors of the multi-GPU
* API are provided here:
* - A single mdspan per rank (drop-in replacement for the old single-GPU
* signature; cuVS wraps it into a one-element vector internally).
* - A `std::vector` of mdspan partitions per rank (multiple partitions per
* rank, typical for Dask/Ray and out-of-core host-data flows).
*
* @code{.cpp}
* // Before (cuVS <= 26.06): implicit multi-GPU dispatch via single-GPU API.
* // raft::resources handle; // attached to NCCL comms or SNMG clique
* // cuvs::cluster::kmeans::fit(handle, params, local_X, std::nullopt,
* // centroids, inertia, n_iter);
*
* // After (cuVS >= 26.08): explicit multi-GPU API.
* raft::resources handle; // NCCL comms or SNMG clique attached
* cuvs::cluster::kmeans::mg::fit(handle, params, local_X, std::nullopt,
* centroids, inertia, n_iter);
* @endcode
*/

/**
* @brief Multi-GPU k-means fit with one or more local data
* partitions per rank.
*
* Each rank supplies its local training data as a vector of partitions. For
* host-resident partitions the implementation streams each partition through
* Lloyd iterations using `params.streaming_batch_size` (per rank). For
* device-resident partitions `streaming_batch_size` is ignored and each local
* partition is processed in full.
*
* The active backend is selected by the resources attached to
* `handle`:
* - When `raft::resource::is_multi_gpu(handle)` is true (SNMG clique), the
* call must be issued from inside an OpenMP region with one thread per
* rank in the clique.
* - Otherwise, multi-process NCCL comms must be initialized on the handle
* (`raft::resource::comms_initialized(handle)`); each process supplies its
* own local partitions.
*
* @param[in] handle The raft handle. Must have NCCL comms or
* a SNMG clique initialized.
* @param[in] params K-means parameters. For host-resident
* partitions the per-rank streaming batch
* size is read from
* `params.streaming_batch_size`; it is
* ignored for device-resident partitions.
* @param[in] X_parts Per-partition local data on this rank.
* Each entry is [n_rows_i x n_features].
* @param[in] sample_weight_parts Optional per-partition row weights with
* one vector per data partition.
* @param[inout] centroids Device matrix [n_clusters x n_features].
* On entry, used as the initial centers
* when `params.init == InitMethod::Array`.
* On return, holds the converged
* centroids.
* @param[out] inertia Host scalar receiving the final
* clustering cost.
* @param[out] n_iter Host scalar receiving the iteration
* count at which the run terminated.
*/
void fit(
raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const float, int>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const float, int>>>& sample_weight_parts,
raft::device_matrix_view<float, int> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Multi-GPU k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const float, int64_t>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const float, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const double, int>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const double, int>>>&
sample_weight_parts,
raft::device_matrix_view<double, int> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Multi-GPU k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const double, int64_t>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const double, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU / out-of-core k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::host_matrix_view<const float, int64_t>>& X_parts,
const std::optional<std::vector<raft::host_vector_view<const float, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU / out-of-core k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::host_matrix_view<const double, int64_t>>& X_parts,
const std::optional<std::vector<raft::host_vector_view<const double, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU k-means fit, single mdspan per rank.
*
* Convenience overload for the common case where each rank has exactly one
* local partition. The mdspan is wrapped in a one-element vector and routed
* through the vector-of-partitions overload above. See that overload's
* documentation for backend selection and handle requirements.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<float, int> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Multi-GPU k-means fit, single mdspan per rank.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::device_matrix_view<const float, int64_t> X,
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU k-means fit, single mdspan per rank.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<double, int> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Multi-GPU k-means fit, single mdspan per rank.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::device_matrix_view<const double, int64_t> X,
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU / out-of-core k-means fit, single mdspan per rank.
*
* Dispatches to the SNMG-clique (batched per-rank) backend when the handle
* carries an SNMG clique, and to the NCCL multi-process backend otherwise.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const float, int64_t> X,
std::optional<raft::host_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU / out-of-core k-means fit, single mdspan per rank.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const double, int64_t> X,
std::optional<raft::host_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @}
*/
} // namespace mg
#endif

namespace helpers {
/**
* @defgroup kmeans_helpers k-means API helpers
Expand Down
Loading
Loading