6 changes: 4 additions & 2 deletions cpp/tests/cluster/kmeans.cu
@@ -381,6 +381,9 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam<KmeansBatchedInputs
params.tol = testparams.tol;
params.rng_state.seed = 1;
params.oversampling_factor = 0;
+ // Limit the number of iterations to ensure same number of iterations for reference and batched
+ // code paths.
+ params.max_iter = 3;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

Contributor

This fix doesn't seem particularly robust, and it looks to me like it will reduce the coverage of the tests, right? The default max_iter for k-means is much higher. How can we test this in a way that validates the ability to stop early, without reducing max_iter to the point where we are essentially not testing early stopping at all?

Contributor Author

I think the point of the test, when we are bit-matching centroids, is to ensure that the accumulation of partial sums within a single iteration in the OOC setting matches the in-memory setting. The test fails on an edge case: the variation due to floating-point precision in the arithmetic is just enough to push the difference in centroids between consecutive iterations below the convergence criterion, so the in-memory path converges one iteration sooner than the OOC path.
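
To make the edge case concrete, here is a minimal toy sketch. It is not the cuVS loop. It assumes a stopping rule of the form "squared centroid shift < tol", and the helper name `iterations_until_stop` and the halving shift are hypothetical, chosen only so the arithmetic is exact and easy to follow.

```cpp
#include <cassert>

// Toy stand-in for a k-means fit loop (hypothetical, not the cuVS code).
// The squared centroid shift is assumed to halve on every iteration; the
// loop stops as soon as the shift drops strictly below tol, or when
// max_iter iterations have run. Returns the number of iterations performed.
int iterations_until_stop(double initial_shift, double tol, int max_iter)
{
  assert(max_iter >= 0);
  double shift = initial_shift;
  int n_iter   = 0;
  while (n_iter < max_iter) {
    shift *= 0.5;  // one "iteration": centroids move half as far as before
    ++n_iter;
    if (shift < tol) { break; }  // convergence check
  }
  return n_iter;
}
```

With `tol = 0.125` and `max_iter = 20`, an initial shift of `0.999` stops after 3 iterations while `1.0` stops after 4, because the third shift lands just below the tolerance in one case and exactly on it in the other. With `max_iter = 3`, both report 3 iterations, which is the effect the cap in this PR relies on.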


auto stream = raft::resource::get_cuda_stream(handle);

@@ -402,8 +405,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam<KmeansBatchedInputs

auto d_sw = d_sw_view();

- params.init = cuvs::cluster::kmeans::params::Array;
- params.max_iter = 20;
+ params.init = cuvs::cluster::kmeans::params::Array;

T ref_inertia = 0;
int ref_n_iter = 0;