From 8565fe1dd6de938d650045692e5e175529293574 Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Tue, 30 Jun 2026 11:43:07 +0200 Subject: [PATCH 01/10] feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: embedding clustering feat: embedding clustering feat: embedding clustering feat: embedding clustering feat: checkpoint feat: checkpoint feat: checkpoint fix: merge feat: checkpoint feat: checkpoint feat: checkpoint fix: merge feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint feat: checkpoint] feat: checkpoint] feat: checkpoint] feat: checkpoint feat: checkpoint --- libs/@local/graph/api/src/rest/entity.rs | 67 +- .../store/postgres/knowledge/entity/mod.rs | 191 ++- .../graph/store/src/embedding/clustering.rs | 1437 +++++++++++++++++ .../graph/store/src/embedding/dimension.rs | 79 + .../graph/store/src/embedding/kernel.rs | 769 +++++++++ libs/@local/graph/store/src/embedding/mod.rs | 15 + libs/@local/graph/store/src/entity/mod.rs | 16 +- libs/@local/graph/store/src/entity/store.rs | 74 +- libs/@local/graph/store/src/error.rs | 15 + libs/@local/graph/store/src/lib.rs | 5 + libs/@local/graph/type-fetcher/src/store.rs | 20 +- tests/graph/integration/postgres/lib.rs | 11 + 12 files changed, 2665 insertions(+), 34 deletions(-) create mode 100644 libs/@local/graph/store/src/embedding/clustering.rs create mode 100644 libs/@local/graph/store/src/embedding/dimension.rs create mode 100644 libs/@local/graph/store/src/embedding/kernel.rs create mode 100644 libs/@local/graph/store/src/embedding/mod.rs diff --git a/libs/@local/graph/api/src/rest/entity.rs b/libs/@local/graph/api/src/rest/entity.rs index dcc4846f143..c1210d5be46 100644 --- a/libs/@local/graph/api/src/rest/entity.rs +++ b/libs/@local/graph/api/src/rest/entity.rs @@ -10,16 +10,16 @@ use hash_graph_postgres_store::store::error::{EntityDoesNotExist, RaceConditionO use hash_graph_store::{ self, entity::{ - ClosedMultiEntityTypeMap, CreateEntityParams, DiffEntityParams, DiffEntityResult, - EntityPermissions, EntityQueryCursor, EntityQuerySortingRecord, EntityQuerySortingToken, - EntityQueryToken, EntityStore, EntityTypesError, EntityValidationReport, - EntityValidationType, HasPermissionForEntitiesParams, LinkDataStateError, - LinkDataValidationReport, LinkError, LinkTargetError, LinkValidationReport, - LinkedEntityError, MetadataValidationReport, PatchEntityParams, - PropertyMetadataValidationReport, QueryConversion, QueryEntitiesResponse, - SearchEntitiesFilter, SearchEntitiesResponse, SummarizeEntitiesParams, - SummarizeEntitiesResponse, UnexpectedEntityType, UpdateEntityEmbeddingsParams, - ValidateEntityComponents, ValidateEntityParams, + ClosedMultiEntityTypeMap, ClusterEntitiesParams, ClusterEntitiesResponse, + CreateEntityParams, DiffEntityParams, DiffEntityResult, EntityCluster, EntityPermissions, + EntityQueryCursor, EntityQuerySortingRecord, EntityQuerySortingToken, EntityQueryToken, + EntityStore, EntityTypesError, EntityValidationReport, EntityValidationType, + HasPermissionForEntitiesParams, LinkDataStateError, LinkDataValidationReport, LinkError, + LinkTargetError, LinkValidationReport, LinkedEntityError, MetadataValidationReport, + PatchEntityParams, PropertyMetadataValidationReport, QueryConversion, + QueryEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, + UnexpectedEntityType, UpdateEntityEmbeddingsParams, ValidateEntityComponents, + ValidateEntityParams, }, entity_type::EntityTypeResolveDefinitions, pool::StorePool, @@ -98,6 +98,7 @@ use crate::rest::{ summarize_entities, patch_entity, update_entity_embeddings, + cluster_entities, diff_entity, ), components( @@ -115,6 +116,9 @@ use crate::rest::{ Embedding, UpdateEntityEmbeddingsParams, EntityEmbedding, + ClusterEntitiesParams, + ClusterEntitiesResponse, + EntityCluster, EntityQueryToken, PatchEntityParams, @@ -228,7 +232,12 @@ impl EntityResource { .route("/bulk", post(create_entities::)) .route("/diff", post(diff_entity::)) .route("/validate", post(validate_entity::)) - .route("/embeddings", post(update_entity_embeddings::)) + .nest( + "/embeddings", + Router::new() + .route("/", post(update_entity_embeddings::)) + .route("/clusters", post(cluster_entities::)), + ) .route("/permissions", post(has_permission_for_entities::)) .route("/search", post(search_entities::)) .nest( @@ -780,6 +789,42 @@ where .map_err(report_to_response) } +#[utoipa::path( + post, + path = "/entities/embeddings/clusters", + tag = "Entity", + params( + ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), + ), + responses( + (status = 200, content_type = "application/json", description = "Clusters of entities by embedding similarity", body = ClusterEntitiesResponse), + (status = 422, content_type = "text/plain", description = "Provided request body is invalid"), + + (status = 500, description = "Store error occurred"), + ), + request_body = ClusterEntitiesParams, +)] +async fn cluster_entities( + AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, + store_pool: Extension>, + temporal_client: Extension>>, + Json(params): Json, +) -> Result, BoxedResponse> +where + S: StorePool + Send + Sync, +{ + let store = store_pool + .acquire(temporal_client.0) + .await + .map_err(report_to_response)?; + + store + .cluster_entities(actor_id, params) + .await + .map_err(report_to_response) + .map(Json) +} + #[utoipa::path( post, path = "/entities/diff", diff --git a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs index 93e56376c68..869bbc46ef8 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs @@ -2,6 +2,7 @@ mod delete; mod query; mod read; mod summary; + use alloc::borrow::Cow; use core::{borrow::Borrow as _, mem}; use std::collections::{HashMap, HashSet}; @@ -17,18 +18,21 @@ use hash_graph_authorization::policies::{ store::{PolicyCreationParams, PrincipalStore as _}, }; use hash_graph_store::{ + embedding::dimension::Dimension, entity::{ - CreateEntityParams, DeleteEntitiesParams, DeletionSummary, EmptyEntityTypes, - EntityPermissions, EntityQueryCursor, EntityQueryPath, EntityQuerySorting, EntityStore, - EntityTypeRetrieval, EntityTypesError, EntityValidationReport, EntityValidationType, - HasPermissionForEntitiesParams, PatchEntityParams, QueryConversion, QueryEntitiesParams, - QueryEntitiesResponse, QueryEntitySubgraphParams, QueryEntitySubgraphResponse, - SearchEntitiesFilter, SearchEntitiesParams, SearchEntitiesResponse, - SummarizeEntitiesParams, SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, - ValidateEntityComponents, ValidateEntityParams, + ClusterEntitiesParams, ClusterEntitiesResponse, CreateEntityParams, DeleteEntitiesParams, + DeletionSummary, EmptyEntityTypes, EntityCluster, EntityPermissions, EntityQueryCursor, + EntityQueryPath, EntityQuerySorting, EntityStore, EntityTypeRetrieval, EntityTypesError, + EntityValidationReport, EntityValidationType, HasPermissionForEntitiesParams, + PatchEntityParams, QueryConversion, QueryEntitiesParams, QueryEntitiesResponse, + QueryEntitySubgraphParams, QueryEntitySubgraphResponse, SummarizeEntitiesParams, + SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityComponents, + ValidateEntityParams, }, entity_type::{EntityTypeStore as _, IncludeEntityTypeOption}, - error::{CheckPermissionError, DeletionError, InsertionError, QueryError, UpdateError}, + error::{ + CheckPermissionError, ClusterError, DeletionError, InsertionError, QueryError, UpdateError, + }, filter::{ Filter, FilterExpression, FilterExpressionList, Parameter, ParameterList, protection::transform_filter, @@ -2549,6 +2553,175 @@ where Ok(permitted_ids) } + + #[expect(clippy::too_many_lines)] + #[tracing::instrument(skip(self, params))] + async fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: ClusterEntitiesParams, + ) -> Result> { + // 3072 fits in u16; compile-time verified. + const { + assert!(Embedding::DIM <= u16::MAX as usize); + } + #[expect( + clippy::cast_possible_truncation, + reason = "guarded by the const assertion above" + )] + const STORED_DIM: u16 = Embedding::DIM as u16; + + let dim = Dimension::new(params.dimension).ok_or_else(|| { + Report::new(ClusterError::InvalidDimension { + dimension: params.dimension, + }) + })?; + + if dim.get() > STORED_DIM { + return Err(Report::new(ClusterError::DimensionTooLarge { + dimension: dim.get(), + max: STORED_DIM, + })); + } + let truncated_dim = usize::from(dim.get()); + + // Filter to entities the actor is allowed to view. + let permitted = self + .has_permission_for_entities( + AuthenticatedActor::from(actor_id), + HasPermissionForEntitiesParams { + action: ActionName::ViewEntity, + entity_ids: Cow::Borrowed(¶ms.entity_ids), + temporal_axes: QueryTemporalAxesUnresolved::TransactionTime { + pinned: PinnedTemporalAxisUnresolved::new(None), + variable: VariableTemporalAxisUnresolved::new(None, None), + }, + include_drafts: false, + }, + ) + .await + .change_context(ClusterError::Store)?; + + let permitted_ids: Vec = params + .entity_ids + .iter() + .filter(|&id| permitted.contains_key(id)) + .copied() + .collect(); + + let entity_uuids: Vec = permitted_ids.iter().map(|id| id.entity_uuid).collect(); + let web_ids: Vec = permitted_ids.iter().map(|id| id.web_id).collect(); + + // Truncate server-side via `subvector` so postgres only sends + // `truncated_dim`-dimensional vectors over the wire. + let row_stream = self + .as_client() + .query_raw( + &format!( + "SELECT + e.web_id, + e.entity_uuid, + subvector(e.embedding, 1, {truncated_dim})::vector({truncated_dim}) AS \ + embedding + FROM entity_embeddings e + WHERE e.property IS NULL + AND (e.web_id, e.entity_uuid) IN (SELECT unnest($1::uuid[]), \ + unnest($2::uuid[]))" + ), + [ + &web_ids as &(dyn ToSql + Sync), + &entity_uuids as &(dyn ToSql + Sync), + ], + ) + .instrument(tracing::info_span!( + "cluster_entities.embeddings", + otel.kind = "client", + db.system = "postgresql", + peer.service = "Postgres", + )) + .await + .change_context(ClusterError::Store)?; + + let mut row_stream = core::pin::pin!(row_stream); + + let mut flat: Vec = Vec::with_capacity(permitted_ids.len() * truncated_dim); + let mut found_ids: Vec = Vec::with_capacity(permitted_ids.len()); + + while let Some(row) = row_stream + .try_next() + .await + .change_context(ClusterError::Store)? + { + let web_id: WebId = row.get(0); + let entity_uuid: EntityUuid = row.get(1); + let embedding: Embedding<'_> = row.get(2); + + flat.extend(embedding.iter()); + + found_ids.push(EntityId { + web_id, + entity_uuid, + draft_id: None, + }); + } + + // Every requested entity not in a cluster goes into + // `missing_embeddings`, whether due to permissions or no embedding. + // Distinguishing the two would leak permission information. + let found_set: HashSet<(WebId, EntityUuid)> = found_ids + .iter() + .map(|id| (id.web_id, id.entity_uuid)) + .collect(); + let missing_embeddings: Vec = params + .entity_ids + .into_iter() + .filter(|id| !found_set.contains(&(id.web_id, id.entity_uuid))) + .collect(); + + if found_ids.is_empty() || params.cluster_count == 0 { + return Ok(ClusterEntitiesResponse { + clusters: Vec::new(), + missing_embeddings, + }); + } + + let config = hash_graph_store::embedding::clustering::Config::for_k_with_seed( + params.cluster_count, + params.seed.unwrap_or_else(|| { + std::time::SystemTime::UNIX_EPOCH + .elapsed() + .map_or(0, |elapsed| { + #[expect( + clippy::cast_possible_truncation, + reason = "seed only needs entropy, truncation is fine" + )] + let seed = elapsed.as_nanos() as u64; + seed + }) + }), + ); + + let result = hash_graph_store::embedding::clustering::cluster(&flat, dim, &config); + + let mut groups: HashMap> = HashMap::new(); + for (index, id) in found_ids.iter().enumerate() { + groups.entry(result.label(index)).or_default().push(*id); + } + + let clusters = groups + .into_iter() + .map(|(cluster_id, entity_ids)| EntityCluster { + cluster_id, + entity_ids, + centroid: result.centroid(cluster_id).to_vec(), + }) + .collect(); + + Ok(ClusterEntitiesResponse { + clusters, + missing_embeddings, + }) + } } #[derive(Debug)] diff --git a/libs/@local/graph/store/src/embedding/clustering.rs b/libs/@local/graph/store/src/embedding/clustering.rs new file mode 100644 index 00000000000..f48da06398a --- /dev/null +++ b/libs/@local/graph/store/src/embedding/clustering.rs @@ -0,0 +1,1437 @@ +use alloc::borrow::Cow; +use core::{cmp, mem, num::NonZero}; + +use rand::{Rng, RngExt as _, SeedableRng as _}; +use rand_xoshiro::Xoshiro256PlusPlus; +use rayon::prelude::*; + +use super::{dimension::Dimension, kernel}; + +/// Parameters for k-means clustering. +/// +/// Use [`Config::for_k`] or [`Config::for_k_with_seed`] to construct with +/// reasonable defaults, then override individual fields as needed. +pub struct Config { + /// Number of clusters. + pub k: u16, + + /// Maximum Lloyd iterations per run before declaring convergence. + pub max_iters: NonZero, + + /// Number of independent restarts. The run with the lowest inertia wins. + pub n_init: NonZero, + + /// Convergence tolerance: a run stops early when the relative change in + /// inertia between iterations falls below this value. + pub tol: f32, + + /// Maximum number of points sampled during k-means++ seeding. + /// Capped to avoid quadratic seeding cost on very large datasets. + pub sample_cap: usize, + + /// Base seed for the PRNG. Each restart derives its own seed from this. + pub seed: u64, + + /// Number of points processed per batch in the assignment step. + pub chunk: NonZero, +} + +impl Config { + /// Creates a configuration for `k` clusters, drawing the seed from `rng`. + #[must_use] + pub(crate) fn for_k(k: u16, mut rng: impl Rng) -> Self { + Self::for_k_with_seed(k, rng.random()) + } + + /// Creates a configuration for `k` clusters with a fixed seed. + /// + /// Defaults: 30 max iterations, 5 restarts, 1e-4 convergence tolerance, + /// sample cap of min(256k, 8192), chunk size 256. + #[must_use] + pub fn for_k_with_seed(k: u16, seed: u64) -> Self { + Self { + k, + max_iters: const { NonZero::new(30).unwrap() }, + n_init: const { NonZero::new(5).unwrap() }, + tol: 1e-4, + sample_cap: cmp::min(256 * usize::from(k), 8192), + seed, + chunk: const { NonZero::new(256).unwrap() }, + } + } +} + +/// Result of spherical k-means clustering. +/// +/// `centroids` is a flat `k * d` row-major buffer where `d` is the +/// embedding [`Dimension`]. Centroid `i` occupies +/// `centroids[i * d .. (i + 1) * d]`. +pub struct Clustering { + pub dimension: Dimension, + + /// Flat centroid matrix, `k * d` elements in row-major order. + pub centroids: Box<[f32]>, + + /// Cluster assignment for each input point, values in `0..k`. + pub labels: Box<[u16]>, +} + +impl Clustering { + /// Allocates a zeroed clustering for `k` centroids over `n` points. + fn new(k: u16, n: usize, d: Dimension) -> Self { + // SAFETY: All-zero bits are valid for `f32` (IEEE 754 positive zero) + // and for `u16` (the integer 0). `Box::new_zeroed_slice` allocates + // zeroed memory of the correct layout, so `assume_init` is sound. + let centroids: Box<[f32]> = + unsafe { Box::new_zeroed_slice((k as usize) * (d.get() as usize)).assume_init() }; + // SAFETY: All-zero bits are valid for `u16` (the integer 0). + let labels: Box<[u16]> = unsafe { Box::new_zeroed_slice(n).assume_init() }; + + Self { + centroids, + labels, + dimension: d, + } + } + + /// Returns the `D`-dimensional slice for centroid `cluster`. + #[must_use] + pub fn centroid(&self, cluster: u16) -> &[f32] { + &self.centroids[cluster as usize * (self.dimension.get() as usize) + ..(cluster + 1) as usize * (self.dimension.get() as usize)] + } + + /// Returns a mutable `D`-dimensional slice for centroid `cluster`. + fn centroid_mut(&mut self, cluster: u16) -> &mut [f32] { + &mut self.centroids[cluster as usize * (self.dimension.get() as usize) + ..(cluster + 1) as usize * (self.dimension.get() as usize)] + } + + /// Returns the cluster label for point `entity`. + #[must_use] + pub fn label(&self, entity: usize) -> u16 { + self.labels[entity] + } + + /// Returns a mutable reference to the cluster label for point `entity`. + fn label_mut(&mut self, entity: usize) -> &mut u16 { + &mut self.labels[entity] + } +} + +// TODO: I wonder if we can make this allocation less +fn sample_indices(n: usize, m: usize, mut rng: impl Rng) -> Vec { + let mut idx: Vec = (0..n).collect(); + + for i in 0..m { + let j = i + rng.random_range(0..n - i); // partial Fisher–Yates + idx.swap(i, j); + } + + idx.truncate(m); + idx +} + +/// Squared chord distance between a point and a unit centroid. +/// +/// For a unit centroid `c` and a point with inverse norm `inv`, the cosine +/// similarity is `dot(point, c) * inv`. The squared chord distance is +/// `2 - 2 * similarity`, which lies in `[0, 4]` and equals `||u - c||²` +/// when `u` is the unit-normalized point. +/// +/// Returns `0.0` for zero-norm points (`point_inv_norm == 0.0`). +/// +/// This is a squared distance. Do not square it again for D² sampling. +#[inline] +fn squared_chord_distance(dot: f32, point_inv_norm: f32) -> f32 { + if point_inv_norm == 0.0 { + return 0.0; + } + + let similarity = (dot * point_inv_norm).clamp(-1.0, 1.0); + + 2.0_f32.mul_add(-similarity, 2.0).max(0.0) +} + +/// Finds the nearest centroid to `point` and returns its index and spherical +/// distance. +/// +/// # Safety +/// +/// * `point.len() == D` +/// * `centroids.len() == k * D` +/// * `k > 0` +/// * `D` is a multiple of 8 (enforced at compile time by the const generic). +#[inline] +#[must_use] +pub(crate) unsafe fn nearest_centroid( + point: &[f32], + point_inv_norm: f32, + centroids: &[f32], + k: usize, + d: usize, +) -> (u16, f32) { + debug_assert_eq!(point.len(), d); + debug_assert_eq!(centroids.len(), k * d); + debug_assert!(k > 0); + + // SAFETY: the caller guarantees these preconditions. The hints let the + // compiler elide bounds checks on the centroid slicing inside the loop. + unsafe { + core::hint::assert_unchecked(point.len() == d); + core::hint::assert_unchecked(centroids.len() == k * d); + core::hint::assert_unchecked(d.is_multiple_of(8)); + core::hint::assert_unchecked(k > 0); + } + + let mut best = 0; + let mut best_dot = f32::NEG_INFINITY; + + for cluster in 0..k { + let start = cluster * d; + let centroid = ¢roids[start..start + d]; + + // SAFETY: `point` and `centroid` both have length `D`, and `D` is a + // multiple of 8 (guaranteed by Dimension). + let dot = unsafe { kernel::dot(point, centroid) }; + + #[expect( + clippy::cast_possible_truncation, + reason = "k is supposed to be low, and checked as such via the config" + )] + if dot > best_dot { + best = cluster as u16; + best_dot = dot; + } + } + + (best, squared_chord_distance(best_dot, point_inv_norm)) +} + +/// Pre-allocated scratch space for the k-means fitting loop. +/// +/// All buffers are allocated once and reused across restarts to avoid +/// per-iteration allocation overhead. +struct Fit { + k: usize, + m: usize, + d: usize, + + /// Current centroids for this restart, `k * d` elements. + centroids: Box<[f32]>, + /// Best centroids seen across all restarts. + best_centroids: Box<[f32]>, + /// Per-cluster accumulator for centroid recomputation, `k * d` elements. + sums: Box<[f32]>, + /// Per-cluster point count for centroid averaging. + counts: Box<[usize]>, + /// Per-sample-point cluster assignment. + labels: Box<[u16]>, + /// Per-sample-point closest centroid distance (for k-means++ seeding). + closest_distances: Box<[f32]>, + /// Tracks which sample points have been selected as seeds. + selected: Box<[bool]>, + /// Lowest inertia across all restarts. + best_inertia: f32, +} + +impl Fit { + fn new(k: usize, m: usize, d: usize) -> Self { + // SAFETY: all-zero bits are valid for f32 (IEEE 754 +0.0), usize (0), u16 (0), and bool + // (false). `Box::new_zeroed_slice` allocates zeroed memory of the correct layout + // for each type, so `assume_init` is sound in every case. + let centroids = unsafe { Box::<[f32]>::new_zeroed_slice(k * d).assume_init() }; + // SAFETY: see above + let best_centroids = unsafe { Box::<[f32]>::new_zeroed_slice(k * d).assume_init() }; + // SAFETY: see above + let sums = unsafe { Box::<[f32]>::new_zeroed_slice(k * d).assume_init() }; + // SAFETY: see above + let counts = unsafe { Box::<[usize]>::new_zeroed_slice(k).assume_init() }; + // SAFETY: see above + let labels = unsafe { Box::<[u16]>::new_zeroed_slice(m).assume_init() }; + // SAFETY: see above + let closest_distances = unsafe { Box::<[f32]>::new_zeroed_slice(m).assume_init() }; + // SAFETY: see above + let selected = unsafe { Box::<[bool]>::new_zeroed_slice(m).assume_init() }; + let best_inertia = f32::INFINITY; + + Self { + k, + m, + d, + centroids, + best_centroids, + sums, + counts, + labels, + closest_distances, + selected, + best_inertia, + } + } + + fn reset_centroids(&mut self) { + self.centroids.fill(0.0); + } + + fn reset_sums(&mut self) { + self.sums.fill(0.0); + } + + fn reset_counts(&mut self) { + self.counts.fill(0); + } + + fn reset_selected(&mut self) { + self.selected.fill(false); + } + + /// Reinitializes empty clusters from the sample point farthest from + /// its assigned centroid. + /// + /// For each empty cluster, scans the sample to find the point with + /// the largest squared chord distance to its current centroid, copies + /// that point as the new centroid (normalized), and updates the + /// point's label so it won't be picked again for subsequent empty + /// clusters in the same pass. + #[expect( + clippy::cast_possible_truncation, + reason = "cluster index < k, and k originates from Config::k (u16)" + )] + fn reinit_empty_clusters(&mut self, sample: &[f32], sample_inv_norms: &[f32]) -> bool { + let &mut Self { d, k, .. } = self; + + let mut reseeded = false; + + for cluster in 0..k { + if self.counts[cluster] != 0 { + continue; + } + + reseeded = true; + let mut farthest_idx = 0; + let mut farthest_dist = -1.0_f32; + + for (i, (point, &inv_norm)) in sample.chunks_exact(d).zip(sample_inv_norms).enumerate() + { + let label = usize::from(self.labels[i]); + let c_start = label * d; + + // SAFETY: point and centroid both have length `d`, + // a multiple of 8 (guaranteed by Dimension). + let dot = unsafe { kernel::dot(point, &self.centroids[c_start..c_start + d]) }; + let dist = squared_chord_distance(dot, inv_norm); + + if dist > farthest_dist { + farthest_dist = dist; + farthest_idx = i; + } + } + + let point_start = farthest_idx * d; + let centroid_start = cluster * d; + self.centroids[centroid_start..centroid_start + d] + .copy_from_slice(&sample[point_start..point_start + d]); + + // SAFETY: centroid row has length `d`, a multiple of 8. + unsafe { + kernel::normalize(&mut self.centroids[centroid_start..centroid_start + d]); + } + + // Update the label so the next empty cluster picks a different + // point (this point's distance to its new centroid is ~0). + self.labels[farthest_idx] = cluster as u16; + } + + reseeded + } + + /// Runs k-means++ initialization followed by Lloyd iterations on the + /// sample, repeating for `n_init` restarts. The best centroids (lowest + /// inertia) are stored in `self.best_centroids`. + fn run( + &mut self, + sample: &[f32], + chunk: usize, + row_chunk: usize, + sample_inv_norms: &[f32], + mut rng: impl Rng, + config: &Config, + ) { + for _ in 0..config.n_init.get() { + self.reset_centroids(); + self.closest_distances.fill(f32::INFINITY); + self.reset_selected(); + + self.seed_plusplus(sample, sample_inv_norms, &mut rng); + + let inertia = self.lloyd(sample, chunk, row_chunk, sample_inv_norms, config); + + if inertia < self.best_inertia { + self.best_inertia = inertia; + mem::swap(&mut self.best_centroids, &mut self.centroids); + } + } + } + + /// Runs Lloyd iterations on the sample until convergence or `max_iters`. + /// Returns the final inertia (sum of distances to assigned centroids). + fn lloyd( + &mut self, + sample: &[f32], + chunk: usize, + row_chunk: usize, + sample_inv_norms: &[f32], + config: &Config, + ) -> f32 { + let &mut Self { d, k, .. } = self; + let mut previous_inertia = f32::INFINITY; + let mut inertia = f32::INFINITY; + + for _ in 0..config.max_iters.get() { + inertia = sample + .par_chunks(row_chunk) + .zip(sample_inv_norms.par_chunks(chunk)) + .zip(self.labels.par_chunks_mut(chunk)) + .map(|((points, inv_norms), labels)| { + let mut inertia = 0.0; + let count = labels.len(); + + // SAFETY: each parallel chunk pairs `count` labels with + // `count * d` floats of point data and `count` inv_norms. + // `d` is a multiple of 8 (guaranteed by Dimension). + unsafe { + core::hint::assert_unchecked(points.len() == count * d); + core::hint::assert_unchecked(inv_norms.len() == count); + core::hint::assert_unchecked(d.is_multiple_of(8)); + } + + let mut i = 0; + while i + 4 <= count { + let p0 = &points[i * d..i * d + d]; + let p1 = &points[(i + 1) * d..(i + 1) * d + d]; + let p2 = &points[(i + 2) * d..(i + 2) * d + d]; + let p3 = &points[(i + 3) * d..(i + 3) * d + d]; + + // SAFETY: each point length d, centroids length k*d, + // k > 0, d a multiple of 8 (guaranteed by Dimension). + let nearest = + unsafe { kernel::nearest4(p0, p1, p2, p3, &self.centroids, k, d) }; + + let inv = [ + inv_norms[i], + inv_norms[i + 1], + inv_norms[i + 2], + inv_norms[i + 3], + ]; + for m in 0..4 { + labels[i + m] = nearest[m].0; + inertia += squared_chord_distance(nearest[m].1, inv[m]); + } + i += 4; + } + + while i < count { + let point = &points[i * d..i * d + d]; + // SAFETY: point length d, centroids length k*d, k > 0, + // d mult of 8. + let (label, distance) = + unsafe { nearest_centroid(point, inv_norms[i], &self.centroids, k, d) }; + labels[i] = label; + inertia += distance; + i += 1; + } + + inertia + }) + .sum(); + + self.reset_sums(); + self.reset_counts(); + + for ((point, label), inv_norm) in sample + .chunks_exact(d) + .zip(self.labels.iter().copied()) + .zip(sample_inv_norms.iter().copied()) + { + let cluster = usize::from(label); + let start = cluster * d; + + self.counts[cluster] += 1; + + if inv_norm == 0.0 { + continue; + } + + // SAFETY: `sums[start..start + d]` and `point` both have + // length `d`, and `d` is a multiple of 8 (guaranteed by Dimension). + unsafe { + kernel::add_scaled_into(&mut self.sums[start..start + d], point, inv_norm); + } + } + + for cluster in 0..k { + if self.counts[cluster] == 0 { + continue; + } + + let start = cluster * d; + let centroid = &mut self.centroids[start..start + d]; + let sum = &self.sums[start..start + d]; + + // SAFETY: `centroid` and `sum` both have length `D`, and `D` + // is a multiple of 8 (guaranteed by Dimension). + unsafe { + #[expect( + clippy::cast_precision_loss, + reason = "cluster count is bounded by sample_cap (≤8192), well within f32 \ + precision" + )] + let inv_count = 1.0 / self.counts[cluster] as f32; + kernel::scale_into(centroid, sum, inv_count); + } + + // SAFETY: centroid rows have length `D`, and `D` is a + // multiple of 8 (guaranteed by Dimension). + unsafe { + kernel::normalize(centroid); + } + } + + let reseeded = self.reinit_empty_clusters(sample, sample_inv_norms); + + // Skip the convergence check when a cluster was just reseeded: + // the reseeded centroid hasn't had an assignment pass yet, so + // breaking now would waste the reinit. + if !reseeded && previous_inertia.is_finite() { + let relative_change = + (previous_inertia - inertia).abs() / previous_inertia.max(f32::EPSILON); + + if relative_change <= config.tol { + break; + } + } + + previous_inertia = inertia; + } + + inertia + } + + /// k-means++ D² weighted seeding. Picks `k` initial centroids from the + /// sample, each chosen with probability proportional to its squared + /// distance from the nearest already-chosen centroid. + fn seed_plusplus(&mut self, sample: &[f32], sample_inv_norms: &[f32], mut rng: impl Rng) { + let &mut Self { d, k, m, .. } = self; + + let mut restart_rng = Xoshiro256PlusPlus::seed_from_u64(rng.random()); + let mut point = restart_rng.random_range(0..m); + + for cluster in 0..k { + let centroid_start = cluster * d; + let point_start = point * d; + + self.centroids[centroid_start..centroid_start + d] + .copy_from_slice(&sample[point_start..point_start + d]); + + // SAFETY: centroid rows have length `D`, and `D` is a multiple of 8 (guaranteed by + // Dimension). + unsafe { + kernel::normalize(&mut self.centroids[centroid_start..centroid_start + d]); + } + + self.selected[point] = true; + + let centroid = &self.centroids[centroid_start..centroid_start + d]; + + let total: f32 = sample + .par_chunks_exact(d) + .zip(sample_inv_norms.par_iter().copied()) + .zip(self.closest_distances.par_iter_mut()) + .enumerate() + .map(|(index, ((point, inv_norm), closest))| { + if self.selected[index] { + *closest = 0.0; + return 0.0; + } + + // SAFETY: `point` and `centroid` both have length `D`, and + // `D` is a multiple of 8 (guaranteed by Dimension). + let dot = unsafe { kernel::dot(point, centroid) }; + let distance = squared_chord_distance(dot, inv_norm); + + if distance < *closest { + *closest = distance; + } + + *closest + }) + .sum(); + + if cluster + 1 == k { + break; + } + + point = if total.is_finite() && total > 0.0 { + let mut target = restart_rng.random_range(0.0..total); + let mut sampled = self + .closest_distances + .iter() + .rposition(|distance| *distance > 0.0) + .unwrap_or(0); + + for (index, distance) in self.closest_distances.iter().copied().enumerate() { + if distance <= 0.0 { + continue; + } + + target -= distance; + + if target <= 0.0 { + sampled = index; + break; + } + } + + sampled + } else { + let remaining = self.selected.iter().filter(|selected| !**selected).count(); + let mut target = restart_rng.random_range(0..remaining); + let mut sampled = 0; + + for (index, selected) in self.selected.iter().copied().enumerate() { + if selected { + continue; + } + + if target == 0 { + sampled = index; + break; + } + + target -= 1; + } + + sampled + }; + } + } +} + +/// Per-thread accumulator for parallel centroid recomputation. +/// +/// Each rayon task gets its own `Accum`; they are merged via [`Accum::merge`] +/// after the parallel fold completes. +struct Accum { + /// Per-cluster sum of normalized points, `k * d` elements. + sums: Box<[f32]>, + /// Per-cluster point count. + counts: Box<[usize]>, +} + +impl Accum { + fn new(k: usize, d: usize) -> Self { + // SAFETY: all-zero bits are valid for f32 (0.0) and usize (0). `assume_init` is + // sound after `new_zeroed_slice`. + let sums = unsafe { Box::<[f32]>::new_zeroed_slice(k * d).assume_init() }; + // SAFETY: see above + let counts = unsafe { Box::<[usize]>::new_zeroed_slice(k).assume_init() }; + Self { sums, counts } + } + + fn merge(mut self, other: &Self, k: usize, d: usize) -> Self { + for cluster in 0..k { + let start = cluster * d; + + self.counts[cluster] += other.counts[cluster]; + + // SAFETY: both cluster sum rows have length `d`, and `d` is a + // multiple of 8 (guaranteed by Dimension). + unsafe { + kernel::add_into( + &mut self.sums[start..start + d], + &other.sums[start..start + d], + ); + } + } + + self + } +} + +/// Assigns all `n` points to their nearest centroid, recomputes centroids +/// from the full population, and re-assigns labels to the final centroids. +/// +/// Uses a parallel fold/reduce: each rayon task accumulates into its own +/// [`Accum`], then results are merged. The final centroids are averaged +/// and normalized in-place. +/// +/// # Safety +/// +/// * `x.len() == n * D` for some `n` +/// * `clustering.centroids.len() == k * D` +/// * `clustering.labels.len() == n` +/// * `k > 0` +/// * `D` is a multiple of 8 (guaranteed by Dimension) +unsafe fn assign(x: &[f32], clustering: &mut Clustering, k: usize, chunk: usize, row_chunk: usize) { + let d = clustering.dimension.get() as usize; + + let full = x + .par_chunks(row_chunk) + .zip(clustering.labels.par_chunks_mut(chunk)) + .fold( + || Accum::new(k, d), + |mut accum, (points, labels)| { + // SAFETY: `cluster` established `x.len() == n * D` and + // `centroids.len() == k * D`. `par_chunks(row_chunk)` with + // `row_chunk = chunk * D` produces chunks where + // `points.len()` is a multiple of `D` and matches `labels.len() * D`. + unsafe { + assign_chunk(&clustering.centroids, k, d, points, labels, &mut accum); + } + + accum + }, + ) + .reduce(|| Accum::new(k, d), |lhs, rhs| lhs.merge(&rhs, k, d)); + + for cluster in 0..k { + if full.counts[cluster] == 0 { + continue; + } + + let start = cluster * d; + + #[expect( + clippy::cast_possible_truncation, + reason = "cluster < k and k originates from Config::k (u16)" + )] + let centroid = clustering.centroid_mut(cluster as u16); + let sum = &full.sums[start..start + d]; + + // SAFETY: centroid and sum both length D, a multiple of 8. + unsafe { + #[expect( + clippy::cast_precision_loss, + reason = "cluster count bounded by n; precision loss acceptable for averaging" + )] + let inv_count = 1.0 / full.counts[cluster] as f32; + kernel::scale_into(centroid, sum, inv_count); + } + // SAFETY: centroid length D, a multiple of 8. + unsafe { + kernel::normalize(centroid); + } + } + + // SAFETY: centroids were just recomputed; same invariants hold. + unsafe { + reassign( + x, + &clustering.centroids, + &mut clustering.labels, + k, + d, + chunk, + row_chunk, + ); + } +} + +/// Processes one parallel chunk of the assignment step: finds the nearest +/// centroid for each point, accumulates normalized points into cluster sums, +/// and records labels. +/// +/// # Safety +/// +/// * `points.len() == labels.len() * d` +/// * `centroids.len() >= k * d` +/// * `d` is a multiple of 8 +/// * `k > 0` +/// * `accum.sums.len() >= k * d` and `accum.counts.len() >= k` +unsafe fn assign_chunk( + centroids: &[f32], + k: usize, + d: usize, + points: &[f32], + labels: &mut [u16], + accum: &mut Accum, +) { + // field path -> disjoint capture of `centroids` only, leaving + // `labels` free for the mutable parallel borrow. + let count = labels.len(); + + // SAFETY: each parallel chunk pairs `count` labels with + // `count * D` floats of point data. `D` is a compile-time + // multiple of 8. + unsafe { + core::hint::assert_unchecked(points.len() == count * d); + core::hint::assert_unchecked(d.is_multiple_of(8)); + } + + let mut i = 0; + while i + 4 <= count { + let p0 = &points[i * d..i * d + d]; + let p1 = &points[(i + 1) * d..(i + 1) * d + d]; + let p2 = &points[(i + 2) * d..(i + 2) * d + d]; + let p3 = &points[(i + 3) * d..(i + 3) * d + d]; + + // SAFETY: each point length D, centroids length k*D, k > 0, D a multiple of 8 (guaranteed + // by Dimension). + let nearest = unsafe { kernel::nearest4(p0, p1, p2, p3, centroids, k, d) }; + let ps = [p0, p1, p2, p3]; + + for m in 0..4 { + let label = nearest[m].0; + labels[i + m] = label; + let cluster = usize::from(label); + accum.counts[cluster] += 1; + + let start = cluster * d; + + // SAFETY: point length D, a multiple of 8. + let norm = unsafe { kernel::dot(ps[m], ps[m]).sqrt() }; + if norm == 0.0 { + continue; + } + + // SAFETY: `sums[start..start + D]` and `point` both have length `D`, and `D` is a + // multiple of 8 (guaranteed by Dimension). + unsafe { + kernel::add_scaled_into(&mut accum.sums[start..start + d], ps[m], norm.recip()); + } + } + i += 4; + } + + while i < count { + let point = &points[i * d..i * d + d]; + // SAFETY: point length D, centroids length k*D, k > 0, D mult of 8. + let (label, _) = unsafe { nearest_centroid(point, 1.0, centroids, k, d) }; + labels[i] = label; + let cluster = usize::from(label); + accum.counts[cluster] += 1; + + let start = cluster * d; + + // SAFETY: point length D. + let norm = unsafe { kernel::dot(point, point).sqrt() }; + if norm != 0.0 { + // SAFETY: `sums[start..start + D]` and `point` both + // have length `D`, and `D` is a multiple of 8. + unsafe { + kernel::add_scaled_into(&mut accum.sums[start..start + d], point, norm.recip()); + } + } + i += 1; + } +} + +/// Processes one parallel chunk of the reassignment step: updates each +/// label to the nearest final centroid. +/// +/// # Safety +/// +/// * `points.len() == labels.len() * d` +/// * `centroids.len() >= k * d` +/// * `d` is a multiple of 8 +/// * `k > 0` +unsafe fn reassign_chunk( + k: usize, + d: usize, + centroids: &[f32], + points: &[f32], + labels: &mut [u16], +) { + let count = labels.len(); + + // SAFETY: each parallel chunk pairs `count` labels with + // `count * D` floats of point data. `D` is a compile-time + // multiple of 8. + unsafe { + core::hint::assert_unchecked(points.len() == count * d); + core::hint::assert_unchecked(d.is_multiple_of(8)); + } + + let mut i = 0; + while i + 4 <= count { + let p0 = &points[i * d..i * d + d]; + let p1 = &points[(i + 1) * d..(i + 1) * d + d]; + let p2 = &points[(i + 2) * d..(i + 2) * d + d]; + let p3 = &points[(i + 3) * d..(i + 3) * d + d]; + + // SAFETY: each point length D, centroids length k*D, k > 0, + // D a multiple of 8 (guaranteed by Dimension). + let nearest = unsafe { kernel::nearest4(p0, p1, p2, p3, centroids, k, d) }; + + labels[i] = nearest[0].0; + labels[i + 1] = nearest[1].0; + labels[i + 2] = nearest[2].0; + labels[i + 3] = nearest[3].0; + i += 4; + } + + while i < count { + let point = &points[i * d..i * d + d]; + // SAFETY: point length D, centroids length k*D, k > 0, D mult of 8. + let (label, _) = unsafe { nearest_centroid(point, 1.0, centroids, k, d) }; + labels[i] = label; + i += 1; + } +} + +/// Re-assigns labels to the nearest final centroid. +/// +/// After centroid recomputation, some boundary points may no longer be +/// nearest to the centroid stored under their label. This pass fixes that. +/// +/// # Safety +/// +/// Same as [`assign`]. +unsafe fn reassign( + x: &[f32], + centroids: &[f32], + labels: &mut [u16], + k: usize, + d: usize, + chunk: usize, + row_chunk: usize, +) { + x.par_chunks(row_chunk) + .zip(labels.par_chunks_mut(chunk)) + .for_each(|(points, labels)| { + // SAFETY: `par_chunks(row_chunk)` with `row_chunk = chunk * D` + // ensures `points.len() == labels.len() * D`. Centroids and k + // are valid from the caller. + unsafe { + reassign_chunk(k, d, centroids, points, labels); + } + }); +} + +/// Runs spherical k-means over a flat row-major embedding matrix. +/// +/// `x` contains `n` points of `dimension` floats each, laid out +/// contiguously. Returns cluster assignments and unit-normalized centroids. +/// +/// # Panics +/// +/// Panics if `x.len()` is not a multiple of `dimension`. +#[must_use] +#[expect(clippy::integer_division_remainder_used, clippy::integer_division)] +pub fn cluster(x: &[f32], dimension: Dimension, config: &Config) -> Clustering { + let d = dimension.get() as usize; + assert!(x.len().is_multiple_of(d)); + + let n = x.len() / d; + let k = cmp::min(config.k, n.saturating_truncate()); + + let mut clustering = Clustering::new(k, n, dimension); + + if k == 0 { + return clustering; + } + + let k = usize::from(k); + let mut rng = Xoshiro256PlusPlus::seed_from_u64(config.seed); + + // 1. subsample (fit on all of n only when n is already small) + let m = config.sample_cap.max(k).min(n); + + let sample = if m == n { + Cow::Borrowed(x) + } else { + let indices = sample_indices(n, m, &mut rng); + let mut sampled = vec![0_f32; m * d]; + + let chunks = sampled.chunks_mut(d); + assert_eq!(chunks.len(), indices.len()); + + for (chunk, index) in chunks.zip(indices) { + chunk.copy_from_slice(&x[index * d..(index + 1) * d]); + } + + Cow::Owned(sampled) + }; + + let sample = sample.as_ref(); + let chunk = config.chunk.get(); + let row_chunk = chunk + .checked_mul(d) + .unwrap_or_else(|| usize::MAX - (usize::MAX % d)) + .max(d); + + let sample_inv_norms: Vec = sample + .par_chunks_exact(d) + .map(|point| { + // SAFETY: every point is a `d`-sized row, and `d` is a multiple of 8 (guaranteed by + // Dimension). + let norm = unsafe { kernel::dot(point, point).sqrt() }; + + if norm > 0.0 { norm.recip() } else { 0.0 } + }) + .collect(); + + // 2. fit on the sample, best of n_init restarts (guards against bad initializations) + let mut fit = Fit::new(k, m, d); + fit.run(sample, chunk, row_chunk, &sample_inv_norms, rng, config); + mem::swap(&mut clustering.centroids, &mut fit.best_centroids); + + // 3. assign points to clusters + // SAFETY: `x.len() == n * d` (asserted above), `clustering.centroids.len() == k * d`, + // `k > 0` (checked above), `d` is a multiple of 8 (guaranteed by Dimension). + unsafe { + assign(x, &mut clustering, k, chunk, row_chunk); + } + + clustering +} + +#[cfg(test)] +mod tests { + #![expect( + clippy::float_cmp, + clippy::integer_division_remainder_used, + reason = "test module: float comparisons are intentional for exact-zero and distance \ + checks; modulo is used in test data construction" + )] + use super::*; + + /// Builds well-separated blob clusters in D-dimensional space. + /// + /// Each blob has a dominant axis so clusters are far apart in cosine + /// space. Returns `(flat_points, ground_truth_labels)`. + #[expect( + clippy::cast_possible_truncation, + reason = "k is small in tests, fits in u16" + )] + fn make_blobs( + points_per_cluster: usize, + k: usize, + seed: u64, + ) -> (Vec, Vec) { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + let n = points_per_cluster * k; + let mut data = vec![0.0_f32; n * D]; + let mut truth = vec![0_u16; n]; + + for c in 0..k { + let axis = c % D; + for p in 0..points_per_cluster { + let idx = c * points_per_cluster + p; + let row = &mut data[idx * D..(idx + 1) * D]; + + row[axis] = 10.0; + for val in row.iter_mut() { + *val += rng.random_range(-0.01..0.01); + } + + truth[idx] = c as u16; + } + } + + (data, truth) + } + + const D: usize = 64; + + fn l2(v: &[f32]) -> f32 { + v.iter().map(|x| x * x).sum::().sqrt() + } + + /// Random unit-norm centroids in `D`-dimensional space. + fn unit_random(k: usize, seed: u64) -> Vec { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + let mut c = vec![0.0_f32; k * D]; + for row in c.chunks_exact_mut(D) { + for v in row.iter_mut() { + *v = rng.random_range(-1.0..1.0); + } + let n = l2(row); + for v in row.iter_mut() { + *v /= n; + } + } + c + } + + /// Brute-force nearest centroid by cosine similarity. + #[expect(clippy::cast_possible_truncation, reason = "k is small in tests")] + fn brute_nearest_cosine(point: &[f32], centroids: &[f32], k: usize) -> u16 { + let pn = l2(point); + let mut best = 0_u16; + let mut best_cos = f32::NEG_INFINITY; + for c in 0..k { + let cent = ¢roids[c * D..(c + 1) * D]; + let d: f32 = point.iter().zip(cent).map(|(a, b)| a * b).sum(); + let cn = l2(cent); + let cos = if pn == 0.0 || cn == 0.0 { + 0.0 + } else { + d / (pn * cn) + }; + if cos > best_cos { + best_cos = cos; + best = c as u16; + } + } + best + } + + /// Computes clustering accuracy using majority-vote label mapping. + /// + /// K-means labels are permutation-invariant, so this assigns each + /// predicted cluster to whichever ground-truth cluster it overlaps + /// most, then counts correctly assigned points. + #[expect( + clippy::cast_precision_loss, + reason = "counts are small test values, well within f64 precision" + )] + fn accuracy(predicted: &[u16], truth: &[u16], k: usize) -> f64 { + let mut votes = vec![vec![0_usize; k]; k]; + for (&pred, &true_label) in predicted.iter().zip(truth) { + votes[pred as usize][true_label as usize] += 1; + } + + let correct: usize = votes + .iter() + .map(|row| row.iter().copied().max().unwrap_or(0)) + .sum(); + + correct as f64 / predicted.len() as f64 + } + + /// Shorthand for [`Dimension::new`] that panics on invalid input. + fn dim(d: u16) -> Dimension { + Dimension::new(d).expect("test dimension must be a positive multiple of 8") + } + + #[test] + fn chord_identical_vectors_is_zero() { + // dot=1.0, inv_norm=1.0 => similarity=1 => distance=0 + assert_eq!(squared_chord_distance(1.0, 1.0), 0.0); + } + + #[test] + fn chord_orthogonal_vectors() { + // dot=0 => similarity=0 => distance=2 + let dist = squared_chord_distance(0.0, 1.0); + assert!((dist - 2.0).abs() < 1e-6, "expected 2.0, got {dist}"); + } + + #[test] + fn chord_opposite_vectors() { + // dot=-1.0 => similarity=-1 => distance=4 + let dist = squared_chord_distance(-1.0, 1.0); + assert!((dist - 4.0).abs() < 1e-6, "expected 4.0, got {dist}"); + } + + #[test] + fn chord_zero_norm_returns_zero() { + assert_eq!(squared_chord_distance(0.5, 0.0), 0.0); + assert_eq!(squared_chord_distance(-0.5, 0.0), 0.0); + } + + #[test] + fn chord_is_non_negative() { + for dot_val in [0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0] { + for inv in [0.0, 0.5, 1.0, 2.0] { + let dist = squared_chord_distance(dot_val, inv); + assert!(dist >= 0.0, "negative for dot={dot_val}, inv={inv}: {dist}"); + } + } + } + + #[test] + fn cluster_empty_input() { + let config = Config::for_k_with_seed(4, 42); + let result = cluster(&[], dim(8), &config); + assert_eq!(result.labels.len(), 0); + assert_eq!(result.centroids.len(), 0); + } + + #[test] + fn cluster_k0() { + let data = vec![1.0_f32; 8]; + let config = Config::for_k_with_seed(0, 42); + let result = cluster(&data, dim(8), &config); + assert_eq!(result.labels.len(), 1); + assert_eq!(result.labels[0], 0); + } + + #[test] + fn cluster_k1_all_same_label() { + let (data, _) = make_blobs::<8>(20, 3, 123); + let config = Config::for_k_with_seed(1, 42); + let result = cluster(&data, dim(8), &config); + + assert_eq!(result.labels.len(), 60); + assert!( + result.labels.iter().all(|&l| l == 0), + "k=1: all labels must be 0" + ); + } + + #[test] + fn cluster_single_point() { + let data = vec![1.0_f32; 16]; + let config = Config::for_k_with_seed(5, 42); + // k clamped to min(k, n) = 1 + let result = cluster(&data, dim(16), &config); + assert_eq!(result.labels.len(), 1); + assert_eq!(result.labels[0], 0); + } + + #[test] + fn cluster_n_less_than_4() { + // n=3 exercises the scalar tail (no nearest4 tiling). + let (data, _) = make_blobs::<8>(1, 3, 99); + let config = Config::for_k_with_seed(3, 42); + let result = cluster(&data, dim(8), &config); + + assert_eq!(result.labels.len(), 3); + let mut seen = [false; 3]; + for &label in &*result.labels { + seen[label as usize] = true; + } + assert!( + seen.iter().all(|&s| s), + "each point should have a unique cluster" + ); + } + + #[test] + fn cluster_n_equals_k() { + let (data, _) = make_blobs::<8>(1, 5, 77); + let config = Config::for_k_with_seed(5, 42); + let result = cluster(&data, dim(8), &config); + + assert_eq!(result.labels.len(), 5); + let mut seen = [false; 5]; + for &label in &*result.labels { + seen[label as usize] = true; + } + assert!( + seen.iter().all(|&s| s), + "n=k: each point should be its own cluster" + ); + } + + #[test] + fn cluster_recovers_well_separated_blobs() { + let (data, truth) = make_blobs::<8>(50, 4, 314); + let config = Config::for_k_with_seed(4, 42); + let result = cluster(&data, dim(8), &config); + + let acc = accuracy(&result.labels, &truth, 4); + assert!( + acc > 0.95, + "expected >95% accuracy on well-separated blobs, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_deterministic_with_same_seed() { + let (data, _) = make_blobs::<8>(30, 3, 555); + + let r1 = cluster(&data, dim(8), &Config::for_k_with_seed(3, 42)); + let r2 = cluster(&data, dim(8), &Config::for_k_with_seed(3, 42)); + + assert_eq!(r1.labels, r2.labels); + assert_eq!(r1.centroids, r2.centroids); + } + + #[test] + fn cluster_different_seeds_may_differ() { + let (data, _) = make_blobs::<8>(30, 3, 555); + + let r1 = cluster(&data, dim(8), &Config::for_k_with_seed(3, 42)); + let r2 = cluster(&data, dim(8), &Config::for_k_with_seed(3, 9999)); + + // Not guaranteed to differ, but with well-separated blobs and + // different seeds the label permutation usually differs. + assert!( + r1.labels != r2.labels, + "different seeds produced identical label vectors (possible but unlikely)" + ); + } + + #[test] + fn cluster_centroids_are_unit_normalized() { + let (data, _) = make_blobs::<8>(40, 4, 222); + let config = Config::for_k_with_seed(4, 42); + let result = cluster(&data, dim(8), &config); + + for c in 0..4_u16 { + let centroid = result.centroid(c); + // SAFETY: centroid has length 8 (= D), a multiple of 8. + let norm = unsafe { kernel::dot(centroid, centroid).sqrt() }; + assert!( + (norm - 1.0).abs() < 1e-5, + "centroid {c} has norm {norm}, expected 1.0" + ); + } + } + + #[test] + fn cluster_labels_in_range() { + let (data, _) = make_blobs::<8>(25, 5, 333); + let config = Config::for_k_with_seed(5, 42); + let result = cluster(&data, dim(8), &config); + + for (i, &label) in result.labels.iter().enumerate() { + assert!(label < 5, "label[{i}] = {label}, expected < 5"); + } + } + + #[test] + fn cluster_labels_nearest_to_assigned_centroid() { + let (data, _) = make_blobs::<8>(30, 3, 444); + let config = Config::for_k_with_seed(3, 42); + let result = cluster(&data, dim(8), &config); + + let k = 3_usize; + let d = 8_usize; + for (i, point) in data.chunks_exact(d).enumerate() { + let assigned = result.labels[i]; + // SAFETY: point and centroid both have length 8 (= D), a multiple of 8. + let assigned_dot = unsafe { kernel::dot(point, result.centroid(assigned)) }; + + #[expect(clippy::cast_possible_truncation, reason = "k=3 fits in u16")] + for c in 0..k as u16 { + // SAFETY: point and centroid both have length 8, a multiple of 8. + let other_dot = unsafe { kernel::dot(point, result.centroid(c)) }; + assert!( + other_dot <= assigned_dot + 1e-5, + "point {i}: assigned to {assigned} (dot={assigned_dot}) but centroid {c} has \ + higher dot={other_dot}" + ); + } + } + } + + #[test] + fn cluster_d32_recovers_blobs() { + let (data, truth) = make_blobs::<32>(40, 3, 888); + let config = Config::for_k_with_seed(3, 42); + let result = cluster(&data, dim(32), &config); + + let acc = accuracy(&result.labels, &truth, 3); + assert!( + acc > 0.95, + "D=32: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_recovers_with_subsampling() { + // n=12000 with sample_cap=1024 exercises the Cow::Owned path. + let (data, truth) = make_blobs::<8>(2000, 6, 21); + let mut config = Config::for_k_with_seed(6, 5); + config.sample_cap = 1024; + let result = cluster(&data, dim(8), &config); + + let acc = accuracy(&result.labels, &truth, 6); + assert!( + acc > 0.95, + "subsampled: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_more_clusters_than_natural_groups() { + // 3 natural groups but k=8: empty clusters keep their seed centroid, + // nothing should be NaN or infinite. + let (data, _) = make_blobs::<8>(400, 3, 31); + let result = cluster(&data, dim(8), &Config::for_k_with_seed(8, 1)); + + assert!( + result.centroids.iter().all(|v| v.is_finite()), + "NaN or infinite centroid" + ); + assert!(result.labels.iter().all(|&l| l < 8)); + } + + #[test] + fn cluster_all_identical_points() { + // Every point identical: D² distances are all zero during seeding, + // which triggers the uniform fallback path. + let n = 100; + let mut data = vec![0.0_f32; n * 8]; + for row in data.chunks_exact_mut(8) { + row[0] = 1.0; + } + let result = cluster(&data, dim(8), &Config::for_k_with_seed(4, 1)); + + assert!(result.centroids.iter().all(|v| v.is_finite())); + assert!(result.labels.iter().all(|&l| l < 4)); + } + + #[test] + fn nearest_centroid_matches_brute_force_cosine() { + let k = 7; + let centroids = unit_random(k, 99); + let mut rng = Xoshiro256PlusPlus::seed_from_u64(100); + + for _ in 0..1000 { + let p: Vec = core::iter::repeat_with(|| rng.random_range(-3.0..3.0)) + .take(D) + .collect(); + let pn = l2(&p); + let inv = if pn > 0.0 { pn.recip() } else { 0.0 }; + + // SAFETY: point has length D=64, centroids has length k*D, + // k > 0, D is a multiple of 8. + let (got, _) = unsafe { nearest_centroid(&p, inv, ¢roids, k, D) }; + assert_eq!( + got, + brute_nearest_cosine(&p, ¢roids, k), + "mismatch for point norm={pn}" + ); + } + } + + #[test] + fn nearest_centroid_argmax_independent_of_inv_norm() { + let k = 5; + let centroids = unit_random(k, 7); + let mut rng = Xoshiro256PlusPlus::seed_from_u64(8); + + for _ in 0..500 { + let p: Vec = core::iter::repeat_with(|| rng.random_range(-2.0..2.0)) + .take(D) + .collect(); + + // SAFETY: point has length D=64, centroids has length k*D, + // k > 0, D is a multiple of 8. + let (a, _) = unsafe { nearest_centroid(&p, 1.0, ¢roids, k, D) }; + // SAFETY: same preconditions. + let (b, _) = unsafe { nearest_centroid(&p, 0.123, ¢roids, k, D) }; + assert_eq!(a, b, "inv_norm must not change the selected centroid"); + } + } + + #[test] + fn cluster_mixed_zero_norm_rows() { + // Some all-zero rows exercise the inv_norm == 0 path in accumulation + // and the squared_chord_distance == 0 return. + let n = 120; + let mut data = vec![0.0_f32; n * 8]; + let mut rng = Xoshiro256PlusPlus::seed_from_u64(7); + for (i, row) in data.chunks_exact_mut(8).enumerate() { + if i % 10 == 0 { + continue; // leave all-zero + } + for v in row.iter_mut() { + *v = rng.random_range(-1.0..1.0); + } + } + let result = cluster(&data, dim(8), &Config::for_k_with_seed(5, 1)); + + assert!(result.centroids.iter().all(|v| v.is_finite())); + assert!(result.labels.iter().all(|&l| l < 5)); + } +} diff --git a/libs/@local/graph/store/src/embedding/dimension.rs b/libs/@local/graph/store/src/embedding/dimension.rs new file mode 100644 index 00000000000..a3a1a31fe94 --- /dev/null +++ b/libs/@local/graph/store/src/embedding/dimension.rs @@ -0,0 +1,79 @@ +use core::num::NonZero; + +/// An embedding vector dimension, guaranteed to be a positive multiple of 8. +/// +/// The multiple-of-8 invariant ensures that the dimension evenly divides into +/// SIMD lanes (8×f32 = `f32x8`), so vectorized kernels can operate without +/// remainder handling. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Dimension(NonZero); + +impl Dimension { + /// Creates a new dimension if `value` is non-zero and a multiple of 8. + /// + /// Returns [`None`] otherwise. + #[must_use] + pub const fn new(value: u16) -> Option { + // not using `?` here because it isn't `const` + let Some(value) = NonZero::new(value) else { + return None; + }; + + if !value.get().is_multiple_of(8) { + return None; + } + + Some(Self(value)) + } + + /// The raw dimension value. + #[must_use] + pub const fn get(self) -> u16 { + self.0.get() + } +} + +pub const D128: Dimension = Dimension(NonZero::new(128).unwrap()); +pub const D256: Dimension = Dimension(NonZero::new(256).unwrap()); +pub const D512: Dimension = Dimension(NonZero::new(512).unwrap()); +pub const D1536: Dimension = Dimension(NonZero::new(1536).unwrap()); +pub const D3072: Dimension = Dimension(NonZero::new(3072).unwrap()); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn valid_multiples_of_8() { + for v in [8, 16, 24, 128, 256, 3072] { + assert!( + Dimension::new(v).is_some(), + "{v} should be a valid dimension" + ); + } + } + + #[test] + fn zero_rejected() { + assert!(Dimension::new(0).is_none()); + } + + #[test] + fn non_multiples_of_8_rejected() { + for v in [1, 2, 3, 4, 5, 6, 7, 9, 10, 15, 17, 100, 3071] { + assert!( + Dimension::new(v).is_none(), + "{v} should not be a valid dimension" + ); + } + } + + #[test] + fn constants_have_correct_values() { + assert_eq!(D128.0.get(), 128); + assert_eq!(D256.0.get(), 256); + assert_eq!(D512.0.get(), 512); + assert_eq!(D1536.0.get(), 1536); + assert_eq!(D3072.0.get(), 3072); + } +} diff --git a/libs/@local/graph/store/src/embedding/kernel.rs b/libs/@local/graph/store/src/embedding/kernel.rs new file mode 100644 index 00000000000..0090f63c4c5 --- /dev/null +++ b/libs/@local/graph/store/src/embedding/kernel.rs @@ -0,0 +1,769 @@ +use core::simd::{Simd, f32x8, num::SimdFloat as _}; + +/// Fused multiply-add when the target has native FMA, separate mul+add otherwise. +/// +/// On aarch64, FMA is part of the base NEON instruction set (`fmla`). +/// On x86_64, FMA requires the `fma` target feature (`vfmadd`); without it, +/// `StdFloat::mul_add` falls back to a per-lane `fmaf` libc call which +/// destroys throughput. The non-FMA path uses a plain multiply and add +/// (`vmulps` + `vaddps`) instead. +#[inline(always)] +#[cfg(not(any(target_arch = "aarch64", target_feature = "fma")))] +fn simd_mul_add(lhs: f32x8, rhs: f32x8, acc: f32x8) -> f32x8 { + lhs * rhs + acc +} + +/// See non-FMA variant above for rationale. +#[inline(always)] +#[cfg(any(target_arch = "aarch64", target_feature = "fma"))] +fn simd_mul_add(lhs: f32x8, rhs: f32x8, acc: f32x8) -> f32x8 { + use std::simd::StdFloat as _; + + lhs.mul_add(rhs, acc) +} + +/// Computes the dot product of two equal-length `f32` slices using SIMD. +/// +/// Four independent accumulators are interleaved to saturate FMA throughput: +/// each accumulator feeds a separate dependency chain, hiding the 4-cycle +/// latency of `fmla`/`vfmadd` on typical micro-architectures. +/// +/// # Safety +/// +/// * `lhs.len() == rhs.len()` +/// * Both lengths are multiples of 8. +#[inline] +#[must_use] +pub(crate) unsafe fn dot(lhs: &[f32], rhs: &[f32]) -> f32 { + debug_assert!(lhs.len().is_multiple_of(8) && lhs.len() == rhs.len()); + + // SAFETY: the caller guarantees equal lengths and a multiple of 8. + // These hints let the compiler elide bounds checks in `as_chunks` and + // the subsequent indexing without raw pointer arithmetic. + unsafe { + core::hint::assert_unchecked(lhs.len() == rhs.len()); + core::hint::assert_unchecked(lhs.len().is_multiple_of(8)); + } + + let (lhs, _) = lhs.as_chunks::<8>(); + let (rhs, _) = rhs.as_chunks::<8>(); + + // SAFETY: both original slices have the same length and that length is a + // multiple of 8, so `as_chunks::<8>` produces equal-length chunk slices + // with empty remainders. + unsafe { + core::hint::assert_unchecked(lhs.len() == rhs.len()); + } + + let mut s0 = f32x8::splat(0.0); + let mut s1 = f32x8::splat(0.0); + let mut s2 = f32x8::splat(0.0); + let mut s3 = f32x8::splat(0.0); + + // Unrolled loop: process 4 chunks (32 floats) per iteration. + let mut offset = 0; + while offset + 4 <= lhs.len() { + let Some([l0, l1, l2, l3]) = lhs[offset..offset + 4].as_array() else { + unreachable!() + }; + let Some([r0, r1, r2, r3]) = rhs[offset..offset + 4].as_array() else { + unreachable!() + }; + + s0 = simd_mul_add(Simd::from_slice(l0), Simd::from_slice(r0), s0); + s1 = simd_mul_add(Simd::from_slice(l1), Simd::from_slice(r1), s1); + s2 = simd_mul_add(Simd::from_slice(l2), Simd::from_slice(r2), s2); + s3 = simd_mul_add(Simd::from_slice(l3), Simd::from_slice(r3), s3); + + offset += 4; + } + + // Tail: process remaining 0..3 chunks one at a time. + #[expect(clippy::min_ident_chars)] + while offset < lhs.len() { + let l = &lhs[offset]; + let r = &rhs[offset]; + + s0 = simd_mul_add(Simd::from_slice(l), Simd::from_slice(r), s0); + offset += 1; + } + + (s0 + s1 + s2 + s3).reduce_sum() +} + +/// Adds `src` element-wise into `dst`. +/// +/// # Safety +/// +/// * `dst.len() == src.len()` +/// * Both lengths are multiples of 8. +#[inline] +pub(crate) unsafe fn add_into(dst: &mut [f32], src: &[f32]) { + debug_assert!(dst.len().is_multiple_of(8) && dst.len() == src.len()); + + // SAFETY: the caller guarantees equal lengths and a multiple of 8. + unsafe { + core::hint::assert_unchecked(dst.len() == src.len()); + core::hint::assert_unchecked(dst.len().is_multiple_of(8)); + } + + let (dst, _) = dst.as_chunks_mut::<8>(); + let (src, _) = src.as_chunks::<8>(); + + // SAFETY: same reasoning as the pre-chunk hints: equal input lengths + // that are multiples of 8 produce equal chunk counts. + unsafe { core::hint::assert_unchecked(dst.len() == src.len()) } + + for index in 0..dst.len() { + dst[index] = (f32x8::from_slice(&dst[index]) + f32x8::from_slice(&src[index])).to_array(); + } +} + +/// Writes `src * factor` element-wise into `dst`. +/// +/// # Safety +/// +/// * `dst.len() == src.len()` +/// * Both lengths are multiples of 8. +#[inline] +pub(crate) unsafe fn scale_into(dst: &mut [f32], src: &[f32], factor: f32) { + debug_assert!(dst.len().is_multiple_of(8) && dst.len() == src.len()); + + // SAFETY: the caller guarantees equal lengths and a multiple of 8. + unsafe { + core::hint::assert_unchecked(dst.len() == src.len()); + core::hint::assert_unchecked(dst.len().is_multiple_of(8)); + } + + let factor = f32x8::splat(factor); + let (dst, _) = dst.as_chunks_mut::<8>(); + let (src, _) = src.as_chunks::<8>(); + + // SAFETY: same reasoning as the pre-chunk hints: equal input lengths + // that are multiples of 8 produce equal chunk counts. + unsafe { core::hint::assert_unchecked(dst.len() == src.len()) } + + for index in 0..dst.len() { + dst[index] = (f32x8::from_slice(&src[index]) * factor).to_array(); + } +} + +/// Scales `value` in-place by `factor`. +/// +/// # Safety +/// +/// * `value.len()` is a multiple of 8. +#[inline] +pub(crate) unsafe fn scale(value: &mut [f32], factor: f32) { + debug_assert!(value.len().is_multiple_of(8)); + + // SAFETY: the caller guarantees a multiple of 8. + unsafe { + core::hint::assert_unchecked(value.len().is_multiple_of(8)); + } + + let factor = f32x8::splat(factor); + let (dst, _) = value.as_chunks_mut::<8>(); + + for dst in dst { + *dst = (f32x8::from_slice(dst) * factor).to_array(); + } +} + +/// Accumulates `src * factor` element-wise into `dst` (`dst += src * factor`). +/// +/// Fuses a scale and add into a single pass, using FMA where available. +/// Avoids the need for a scratch buffer when accumulating normalized vectors. +/// +/// # Safety +/// +/// * `dst.len() == src.len()` +/// * Both lengths are multiples of 8. +#[inline] +pub(crate) unsafe fn add_scaled_into(dst: &mut [f32], src: &[f32], factor: f32) { + debug_assert!(dst.len().is_multiple_of(8) && dst.len() == src.len()); + + // SAFETY: the caller guarantees equal lengths and a multiple of 8. + unsafe { + core::hint::assert_unchecked(dst.len() == src.len()); + core::hint::assert_unchecked(dst.len().is_multiple_of(8)); + } + + let factor = f32x8::splat(factor); + let (dst, _) = dst.as_chunks_mut::<8>(); + let (src, _) = src.as_chunks::<8>(); + + // SAFETY: same reasoning as the pre-chunk hints: equal input lengths + // that are multiples of 8 produce equal chunk counts. + unsafe { core::hint::assert_unchecked(dst.len() == src.len()) } + + for index in 0..dst.len() { + let acc = f32x8::from_slice(&dst[index]); + let val = f32x8::from_slice(&src[index]); + dst[index] = simd_mul_add(val, factor, acc).to_array(); + } +} + +/// Normalizes `value` to unit length in-place. +/// +/// If the vector has zero norm, it is left unchanged. +/// +/// # Safety +/// +/// * `value.len()` is a multiple of 8. +#[inline] +pub(crate) unsafe fn normalize(value: &mut [f32]) { + // SAFETY: `dot` requires equal lengths (trivially true, same slice) + // and a multiple of 8 (guaranteed by the caller). + let norm = unsafe { dot(value, value).sqrt() }; + + if norm > 0.0 { + let factor = 1.0 / norm; + // SAFETY: same slice, same length guarantee. + unsafe { + scale(value, factor); + } + } +} + +/// 4 points x 2 centroids. Eight independent accumulators give ILP 8 (enough to +/// saturate FMA throughput); each point chunk feeds 2 FMAs and each centroid +/// chunk feeds 4. Returns `dot[point][centroid]`. +/// +/// Register budget: 8 `f32x8` accumulators. On AVX2 (16 ymm) this leaves room +/// for the 6 operand loads. On NEON each `f32x8` is two 128-bit regs, so the 8 +/// accumulators take 16 of 32 registers; a 4x4 tile (16 accumulators) also fits +/// there if you want more centroid reuse. Either way, check the asm shows the +/// accumulators staying in registers with no stack spills, and that +/// `simd_mul_add` lowered to `vfmadd`/`fmla` and not a `fmaf` call. If the array +/// form ever spills, the manual unroll below is what keeps them in registers. +/// +/// # Safety +/// * all six slices have length `d` +/// * `d` is a multiple of 8 +#[expect( + clippy::inline_always, + reason = "micro-kernel must inline into nearest4 to keep accumulators in registers" +)] +#[inline(always)] +pub(crate) unsafe fn micro_4x2( + p0: &[f32], + p1: &[f32], + p2: &[f32], + p3: &[f32], + c0: &[f32], + c1: &[f32], +) -> [[f32; 2]; 4] { + debug_assert!(p0.len().is_multiple_of(8)); + debug_assert!( + [p1.len(), p2.len(), p3.len(), c0.len(), c1.len()] + .iter() + .all(|&l| l == p0.len()) + ); + + let (p0, _) = p0.as_chunks::<8>(); + let (p1, _) = p1.as_chunks::<8>(); + let (p2, _) = p2.as_chunks::<8>(); + let (p3, _) = p3.as_chunks::<8>(); + let (c0, _) = c0.as_chunks::<8>(); + let (c1, _) = c1.as_chunks::<8>(); + + // SAFETY: the caller guarantees all six slices have equal length `d`, + // and `d` is a multiple of 8. The hints let the compiler prove that + // `as_chunks` produces equal-length chunk slices. + unsafe { + core::hint::assert_unchecked(p0.len() == p1.len()); + core::hint::assert_unchecked(p0.len() == p2.len()); + core::hint::assert_unchecked(p0.len() == p3.len()); + core::hint::assert_unchecked(p0.len() == c0.len()); + core::hint::assert_unchecked(p0.len() == c1.len()); + } + + let mut a00 = f32x8::splat(0.0); + let mut a01 = f32x8::splat(0.0); + let mut a10 = f32x8::splat(0.0); + let mut a11 = f32x8::splat(0.0); + let mut a20 = f32x8::splat(0.0); + let mut a21 = f32x8::splat(0.0); + let mut a30 = f32x8::splat(0.0); + let mut a31 = f32x8::splat(0.0); + + for t in 0..c0.len() { + let v0 = Simd::from_array(c0[t]); + let v1 = Simd::from_array(c1[t]); + let x0 = Simd::from_array(p0[t]); + let x1 = Simd::from_array(p1[t]); + let x2 = Simd::from_array(p2[t]); + let x3 = Simd::from_array(p3[t]); + + // super::simd_mul_add picks the FMA arm per target. + a00 = simd_mul_add(x0, v0, a00); + a01 = simd_mul_add(x0, v1, a01); + a10 = simd_mul_add(x1, v0, a10); + a11 = simd_mul_add(x1, v1, a11); + a20 = simd_mul_add(x2, v0, a20); + a21 = simd_mul_add(x2, v1, a21); + a30 = simd_mul_add(x3, v0, a30); + a31 = simd_mul_add(x3, v1, a31); + } + + [ + [a00.reduce_sum(), a01.reduce_sum()], + [a10.reduce_sum(), a11.reduce_sum()], + [a20.reduce_sum(), a21.reduce_sum()], + [a30.reduce_sum(), a31.reduce_sum()], + ] +} + +/// Finds the nearest centroid for 4 points simultaneously using the +/// [`micro_4x2`] tiled kernel. +/// +/// Returns `(centroid_index, raw_dot_product)` for each of the 4 points. +/// The raw dot product is **not** a distance; the caller must convert via +/// [`squared_chord_distance`](super::clustering::squared_chord_distance) +/// if needed. +/// +/// # Safety +/// +/// * All four point slices have length `d`. +/// * `centroids.len() >= k * d`. +/// * `d` is a multiple of 8. +/// * `k > 0`. +#[inline] +#[must_use] +pub(crate) unsafe fn nearest4( + p0: &[f32], + p1: &[f32], + p2: &[f32], + p3: &[f32], + centroids: &[f32], + k: usize, + d: usize, +) -> [(u16, f32); 4] { + let mut best_dot = [f32::NEG_INFINITY; 4]; + let mut best_idx = [0_u16; 4]; + + // SAFETY: the caller guarantees these preconditions. + unsafe { + core::hint::assert_unchecked(p0.len() == d); + core::hint::assert_unchecked(p0.len() == p1.len()); + core::hint::assert_unchecked(p0.len() == p2.len()); + core::hint::assert_unchecked(p0.len() == p3.len()); + core::hint::assert_unchecked(centroids.len() >= k * d); + core::hint::assert_unchecked(d.is_multiple_of(8)); + core::hint::assert_unchecked(k > 0); + } + + let mut j = 0; + while j + 2 <= k { + // SAFETY: `j + 2 <= k` and `centroids.len() >= k * d`, so both + // slices `[j*d .. (j+2)*d]` are in-bounds. + let c0 = unsafe { centroids.get_unchecked(j * d..j * d + d) }; + // SAFETY: see above. + let c1 = unsafe { centroids.get_unchecked((j + 1) * d..(j + 1) * d + d) }; + + // SAFETY: all six slices have length `d`, a multiple of 8. + let dots = unsafe { micro_4x2(p0, p1, p2, p3, c0, c1) }; + + #[expect( + clippy::cast_possible_truncation, + reason = "k originates from Config::k (u16), so j < k fits in u16" + )] + for m in 0..4 { + if dots[m][0] > best_dot[m] { + best_dot[m] = dots[m][0]; + best_idx[m] = j as u16; + } + if dots[m][1] > best_dot[m] { + best_dot[m] = dots[m][1]; + best_idx[m] = (j + 1) as u16; + } + } + j += 2; + } + + // Handle odd k: one remaining centroid. + if j < k { + let c = ¢roids[j * d..j * d + d]; + let ps = [p0, p1, p2, p3]; + for m in 0..4 { + // SAFETY: point and centroid both have length `d`, a multiple of 8. + let d = unsafe { dot(ps[m], c) }; + #[expect( + clippy::cast_possible_truncation, + reason = "k originates from Config::k (u16)" + )] + if d > best_dot[m] { + best_dot[m] = d; + best_idx[m] = j as u16; + } + } + } + + [ + (best_idx[0], best_dot[0]), + (best_idx[1], best_dot[1]), + (best_idx[2], best_dot[2]), + (best_idx[3], best_dot[3]), + ] +} + +#[cfg(test)] +mod tests { + #![expect(clippy::float_cmp, clippy::integer_division_remainder_used)] + + use super::*; + + /// Scalar dot product for reference. + fn ref_dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(x, y)| x * y).sum() + } + + /// Scalar normalize for reference. + fn ref_normalize(v: &mut [f32]) { + let norm = v.iter().map(|x| x * x).sum::().sqrt(); + if norm > 0.0 { + for x in v { + *x /= norm; + } + } + } + + /// Deterministic test vector: entry `i` gets `(i+1) * scale`. + #[expect(clippy::cast_precision_loss)] + fn ramp(len: usize, factor: f32) -> Vec { + (0..len).map(|i| (i + 1) as f32 * factor).collect() + } + + /// Asserts two f32 values are within relative tolerance, with an absolute + /// floor for values near zero. + fn assert_close(a: f32, b: f32, tol: f32) { + let diff = (a - b).abs(); + let denom = a.abs().max(b.abs()).max(1e-12); + assert!( + diff / denom < tol, + "values differ: {a} vs {b} (diff={diff}, rel={})", + diff / denom + ); + } + + #[test] + fn dot_matches_scalar_d8() { + let a = ramp(8, 1.0); + let b = ramp(8, 0.5); + let expected = ref_dot(&a, &b); + // SAFETY: both slices have length 8, a multiple of 8. + let got = unsafe { dot(&a, &b) }; + assert_close(got, expected, 1e-6); + } + + #[test] + fn dot_matches_scalar_d24() { + // 24 = 3 chunks of 8: the 4-unrolled body runs 0 iterations, + // all 3 chunks go through the tail path. + let a = ramp(24, 0.1); + let b = ramp(24, -0.2); + let expected = ref_dot(&a, &b); + // SAFETY: both slices have length 24, a multiple of 8. + let got = unsafe { dot(&a, &b) }; + assert_close(got, expected, 1e-5); + } + + #[test] + fn dot_matches_scalar_d3072() { + let a = ramp(3072, 0.001); + let b = ramp(3072, -0.002); + let expected = ref_dot(&a, &b); + // SAFETY: both slices have length 3072, a multiple of 8. + let got = unsafe { dot(&a, &b) }; + assert_close(got, expected, 1e-4); + } + + #[test] + fn dot_is_commutative() { + let a = ramp(32, 0.3); + let b = ramp(32, -0.7); + // SAFETY: both slices have length 32, a multiple of 8. + let ab = unsafe { dot(&a, &b) }; + // SAFETY: same slices, reversed. + let ba = unsafe { dot(&b, &a) }; + assert_eq!(ab, ba); + } + + #[test] + fn dot_self_is_squared_norm() { + let a = ramp(16, 0.5); + let expected: f32 = a.iter().map(|x| x * x).sum(); + // SAFETY: both arguments are the same 16-element slice. + let got = unsafe { dot(&a, &a) }; + assert_close(got, expected, 1e-6); + } + + #[test] + fn dot_orthogonal_is_zero() { + let mut a = vec![0.0_f32; 8]; + let mut b = vec![0.0_f32; 8]; + a[0] = 1.0; + b[1] = 1.0; + // SAFETY: both slices have length 8. + let got = unsafe { dot(&a, &b) }; + assert_eq!(got, 0.0); + } + + #[test] + fn add_into_matches_scalar() { + let src = ramp(16, 1.0); + let mut dst = ramp(16, 0.5); + let expected: Vec = dst.iter().zip(&src).map(|(d, s)| d + s).collect(); + // SAFETY: both slices have length 16, a multiple of 8. + unsafe { add_into(&mut dst, &src) } + assert_eq!(dst, expected); + } + + #[test] + fn add_into_zero_is_identity() { + let zeros = vec![0.0_f32; 24]; + let mut dst = ramp(24, 1.0); + let original = dst.clone(); + // SAFETY: both slices have length 24, a multiple of 8. + unsafe { add_into(&mut dst, &zeros) } + assert_eq!(dst, original); + } + + #[test] + fn scale_into_matches_scalar() { + let src = ramp(16, 1.0); + let mut dst = vec![0.0_f32; 16]; + let factor = 2.5; + let expected: Vec = src.iter().map(|x| x * factor).collect(); + // SAFETY: both slices have length 16, a multiple of 8. + unsafe { scale_into(&mut dst, &src, factor) } + assert_eq!(dst, expected); + } + + #[test] + fn scale_into_zero_gives_zeros() { + let src = ramp(8, 1.0); + let mut dst = ramp(8, 999.0); + // SAFETY: both slices have length 8, a multiple of 8. + unsafe { scale_into(&mut dst, &src, 0.0) } + assert!(dst.iter().all(|&x| x == 0.0)); + } + + #[test] + fn scale_into_one_is_copy() { + let src = ramp(16, 0.3); + let mut dst = vec![0.0_f32; 16]; + // SAFETY: both slices have length 16, a multiple of 8. + unsafe { scale_into(&mut dst, &src, 1.0) } + assert_eq!(dst, src); + } + + #[test] + fn scale_matches_scalar() { + let mut v = ramp(16, 1.0); + let factor = -0.5; + let expected: Vec = v.iter().map(|x| x * factor).collect(); + // SAFETY: slice has length 16, a multiple of 8. + unsafe { scale(&mut v, factor) } + assert_eq!(v, expected); + } + + #[test] + fn add_scaled_into_matches_separate_ops() { + let src = ramp(16, 1.0); + let factor = 0.3; + let mut dst = ramp(16, 0.5); + let expected: Vec = dst.iter().zip(&src).map(|(d, s)| d + s * factor).collect(); + + // SAFETY: both slices have length 16, a multiple of 8. + unsafe { add_scaled_into(&mut dst, &src, factor) } + + for (&got, &exp) in dst.iter().zip(&expected) { + assert_close(got, exp, 1e-6); + } + } + + #[test] + fn add_scaled_into_factor_zero_is_identity() { + let src = ramp(8, 100.0); + let mut dst = ramp(8, 1.0); + let original = dst.clone(); + // SAFETY: both slices have length 8, a multiple of 8. + unsafe { add_scaled_into(&mut dst, &src, 0.0) } + assert_eq!(dst, original); + } + + #[test] + fn normalize_produces_unit_norm() { + let mut v = ramp(32, 0.7); + // SAFETY: length 32, a multiple of 8. + unsafe { normalize(&mut v) } + // SAFETY: same slice, length unchanged. + let norm = unsafe { dot(&v, &v).sqrt() }; + assert_close(norm, 1.0, 1e-6); + } + + #[test] + fn normalize_preserves_direction() { + let mut v = ramp(16, 2.0); + let mut ref_v = v.clone(); + ref_normalize(&mut ref_v); + // SAFETY: length 16, a multiple of 8. + unsafe { normalize(&mut v) } + for (&a, &b) in v.iter().zip(&ref_v) { + assert_close(a, b, 1e-6); + } + } + + #[test] + fn normalize_zero_vector_unchanged() { + let mut v = vec![0.0_f32; 8]; + // SAFETY: length 8, a multiple of 8. + unsafe { normalize(&mut v) } + assert!(v.iter().all(|&x| x == 0.0)); + } + + #[test] + fn normalize_already_unit_is_stable() { + let mut v = vec![0.0_f32; 8]; + v[0] = 1.0; + // SAFETY: length 8, a multiple of 8. + unsafe { normalize(&mut v) } + assert_close(v[0], 1.0, 1e-7); + assert!(v[1..].iter().all(|&x| x == 0.0)); + } + + #[test] + fn micro_4x2_matches_individual_dots() { + let d = 16; + let p0 = ramp(d, 0.1); + let p1 = ramp(d, -0.2); + let p2 = ramp(d, 0.3); + let p3 = ramp(d, -0.4); + let c0 = ramp(d, 0.5); + let c1 = ramp(d, -0.6); + + // SAFETY: all 6 slices have length 16, a multiple of 8. + let got = unsafe { micro_4x2(&p0, &p1, &p2, &p3, &c0, &c1) }; + + let expected = [ + [ref_dot(&p0, &c0), ref_dot(&p0, &c1)], + [ref_dot(&p1, &c0), ref_dot(&p1, &c1)], + [ref_dot(&p2, &c0), ref_dot(&p2, &c1)], + [ref_dot(&p3, &c0), ref_dot(&p3, &c1)], + ]; + + for (g, e) in got.iter().zip(&expected) { + assert_close(g[0], e[0], 1e-5); + assert_close(g[1], e[1], 1e-5); + } + } + + #[test] + fn micro_4x2_d3072() { + let d = 3072; + let p0 = ramp(d, 0.001); + let p1 = ramp(d, -0.001); + let p2 = ramp(d, 0.002); + let p3 = ramp(d, -0.002); + let c0 = ramp(d, 0.001); + let c1 = ramp(d, -0.001); + + // SAFETY: all 6 slices have length 3072, a multiple of 8. + let got = unsafe { micro_4x2(&p0, &p1, &p2, &p3, &c0, &c1) }; + + let expected = [ + [ref_dot(&p0, &c0), ref_dot(&p0, &c1)], + [ref_dot(&p1, &c0), ref_dot(&p1, &c1)], + [ref_dot(&p2, &c0), ref_dot(&p2, &c1)], + [ref_dot(&p3, &c0), ref_dot(&p3, &c1)], + ]; + + for (g, e) in got.iter().zip(&expected) { + assert_close(g[0], e[0], 1e-3); + assert_close(g[1], e[1], 1e-3); + } + } + + #[test] + fn nearest4_matches_brute_force_even_k() { + let d = 8; + let k = 4; + + // 4 centroids: axis-aligned unit vectors. + let mut centroids = vec![0.0_f32; k * d]; + for i in 0..k { + centroids[i * d + i] = 1.0; + } + + // 4 points, each close to a different centroid. + let mut points: [Vec; 4] = core::array::from_fn(|_| vec![0.0_f32; d]); + for i in 0..4 { + points[i][i] = 10.0; + points[i][(i + 1) % d] = 0.1; + } + + // SAFETY: d=8 (multiple of 8), k=4 > 0, centroids has length k*d, + // all point slices have length d. + let got = unsafe { + nearest4( + &points[0], &points[1], &points[2], &points[3], ¢roids, k, d, + ) + }; + + assert_eq!(got[0].0, 0); + assert_eq!(got[1].0, 1); + assert_eq!(got[2].0, 2); + assert_eq!(got[3].0, 3); + } + + #[test] + fn nearest4_matches_brute_force_odd_k() { + let d = 8; + let k = 3; // odd: exercises the remainder path + + let mut centroids = vec![0.0_f32; k * d]; + for i in 0..k { + centroids[i * d + i] = 1.0; + } + + let mut points: [Vec; 4] = core::array::from_fn(|_| vec![0.0_f32; d]); + points[0][0] = 5.0; + points[1][1] = 5.0; + points[2][2] = 5.0; + points[3][0] = 3.0; // closest to centroid 0 + + // SAFETY: d=8 (multiple of 8), k=3 > 0, centroids has length k*d, + // all point slices have length d. + let got = unsafe { + nearest4( + &points[0], &points[1], &points[2], &points[3], ¢roids, k, d, + ) + }; + + assert_eq!(got[0].0, 0); + assert_eq!(got[1].0, 1); + assert_eq!(got[2].0, 2); + assert_eq!(got[3].0, 0); + } + + #[test] + fn nearest4_k1_all_same() { + let d = 8; + let centroids = ramp(d, 1.0); + let p0 = ramp(d, 0.1); + let p1 = ramp(d, -0.2); + let p2 = ramp(d, 0.3); + let p3 = ramp(d, -0.4); + + // SAFETY: d=8 (multiple of 8), k=1 > 0, centroids has length d, + // all point slices have length d. + let got = unsafe { nearest4(&p0, &p1, &p2, &p3, ¢roids, 1, d) }; + + assert_eq!(got[0].0, 0); + assert_eq!(got[1].0, 0); + assert_eq!(got[2].0, 0); + assert_eq!(got[3].0, 0); + } +} diff --git a/libs/@local/graph/store/src/embedding/mod.rs b/libs/@local/graph/store/src/embedding/mod.rs new file mode 100644 index 00000000000..cefb998fb67 --- /dev/null +++ b/libs/@local/graph/store/src/embedding/mod.rs @@ -0,0 +1,15 @@ +#![expect( + unsafe_code, + dead_code, + clippy::indexing_slicing, + clippy::float_arithmetic, + clippy::min_ident_chars, + clippy::many_single_char_names, + reason = "embedding module is under active development; dead_code is expected until the \ + public API is wired up. Single-char idents (k, n, m, d, x) are standard \ + mathematical notation for clustering." +)] + +pub mod clustering; +pub mod dimension; +pub(crate) mod kernel; diff --git a/libs/@local/graph/store/src/entity/mod.rs b/libs/@local/graph/store/src/entity/mod.rs index f1174879703..df563a365cc 100644 --- a/libs/@local/graph/store/src/entity/mod.rs +++ b/libs/@local/graph/store/src/entity/mod.rs @@ -4,14 +4,14 @@ pub use self::{ EntityQuerySortingToken, EntityQueryToken, }, store::{ - ClosedMultiEntityTypeMap, CreateEntityParams, DeleteEntitiesParams, DeletionScope, - DeletionSummary, DiffEntityParams, DiffEntityResult, EntityPermissions, EntityStore, - EntityValidationType, HasPermissionForEntitiesParams, LinkDeletionBehavior, - PatchEntityParams, QueryConversion, QueryEntitiesParams, QueryEntitiesResponse, - QueryEntitySubgraphParams, QueryEntitySubgraphResponse, SearchEntitiesFilter, - SearchEntitiesParams, SearchEntitiesResponse, SummarizeEntitiesParams, - SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityComponents, - ValidateEntityError, ValidateEntityParams, + ClosedMultiEntityTypeMap, ClusterEntitiesParams, ClusterEntitiesResponse, + CreateEntityParams, DeleteEntitiesParams, DeletionScope, DeletionSummary, DiffEntityParams, + DiffEntityResult, EntityCluster, EntityPermissions, EntityStore, EntityValidationType, + HasPermissionForEntitiesParams, LinkDeletionBehavior, PatchEntityParams, QueryConversion, + QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, + QueryEntitySubgraphResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, + UpdateEntityEmbeddingsParams, ValidateEntityComponents, ValidateEntityError, + ValidateEntityParams, }, validation_report::{ EmptyEntityTypes, EntityRetrieval, EntityTypeRetrieval, EntityTypesError, diff --git a/libs/@local/graph/store/src/entity/store.rs b/libs/@local/graph/store/src/entity/store.rs index 94ff152a36b..1af2a5af1c4 100644 --- a/libs/@local/graph/store/src/entity/store.rs +++ b/libs/@local/graph/store/src/entity/store.rs @@ -36,7 +36,9 @@ use utoipa::{ use crate::{ entity::{EntityQueryCursor, EntityQuerySorting, EntityValidationReport}, entity_type::{EntityTypeResolveDefinitions, IncludeEntityTypeOption}, - error::{CheckPermissionError, DeletionError, InsertionError, QueryError, UpdateError}, + error::{ + CheckPermissionError, ClusterError, DeletionError, InsertionError, QueryError, UpdateError, + }, filter::{Filter, SemanticDistance}, subgraph::{ Subgraph, @@ -525,6 +527,55 @@ impl PatchEntityParams { } } +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct ClusterEntitiesParams { + pub entity_ids: Vec, + /// Desired number of clusters. Clamped to the number of entities with + /// embeddings when that is smaller. + pub cluster_count: u16, + /// Embedding dimension after matryoshka truncation. Must be a positive + /// multiple of 8; values above 3072 are rejected. Defaults to 256. + #[serde(default = "ClusterEntitiesParams::default_dimension")] + pub dimension: u16, + + /// Seed for the random number generator used in clustering. + /// + /// If not provided, a random seed will be used. + pub seed: Option, +} + +impl ClusterEntitiesParams { + const fn default_dimension() -> u16 { + 256 + } +} + +/// One cluster from a spherical k-means run over entity embeddings. +#[derive(Debug, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] +pub struct EntityCluster { + /// Index in `0..cluster_count`. + pub cluster_id: u16, + pub entity_ids: Vec, + /// Unit-normalized centroid with length equal to the requested dimension. + pub centroid: Vec, +} + +/// Result of [`EntityStore::cluster_entities`]. +#[derive(Debug, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] +pub struct ClusterEntitiesResponse { + /// One entry per non-empty cluster. Empty clusters (no points assigned) + /// are omitted. + pub clusters: Vec, + /// Entities from the request that had no stored embedding. + pub missing_embeddings: Vec, +} + #[derive(Debug, Deserialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] #[serde(rename_all = "camelCase", deny_unknown_fields)] @@ -912,6 +963,27 @@ pub trait EntityStore { params: UpdateEntityEmbeddingsParams<'_>, ) -> impl Future>> + Send; + /// Groups entities by embedding similarity using spherical k-means. + /// + /// Each entity's combined embedding is truncated to the requested + /// dimension (matryoshka encoding) before clustering. The returned + /// centroids are unit-normalized and have the same dimension. + /// + /// Entities without a stored embedding are not clustered; they appear + /// in [`ClusterEntitiesResponse::missing_embeddings`]. + /// + /// # Errors + /// + /// Returns [`ClusterError::InvalidDimension`] if the dimension is not a + /// positive multiple of 8, [`ClusterError::DimensionTooLarge`] if it + /// exceeds the stored embedding width, or [`ClusterError::Store`] if the + /// embedding query fails. + fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: ClusterEntitiesParams, + ) -> impl Future>> + Send; + /// Re-indexes the cache for entities. /// /// This is only needed if the entity was changed in place without an update procedure. This is diff --git a/libs/@local/graph/store/src/error.rs b/libs/@local/graph/store/src/error.rs index d9f39229103..fc40ef8daf4 100644 --- a/libs/@local/graph/store/src/error.rs +++ b/libs/@local/graph/store/src/error.rs @@ -68,3 +68,18 @@ pub enum CheckPermissionError { } impl Error for CheckPermissionError {} + +/// Failure to cluster entities by embedding similarity. +#[derive(Debug, derive_more::Display)] +#[display("Could not cluster entities: {_variant}")] +#[must_use] +pub enum ClusterError { + #[display("dimension {dimension} is not a positive multiple of 8")] + InvalidDimension { dimension: u16 }, + #[display("dimension {dimension} exceeds stored embedding dimension {max}")] + DimensionTooLarge { dimension: u16, max: u16 }, + #[display("embedding query failed")] + Store, +} + +impl Error for ClusterError {} diff --git a/libs/@local/graph/store/src/lib.rs b/libs/@local/graph/store/src/lib.rs index 2931d3f0eca..668c46b20b8 100644 --- a/libs/@local/graph/store/src/lib.rs +++ b/libs/@local/graph/store/src/lib.rs @@ -5,6 +5,10 @@ #![feature( // Language Features impl_trait_in_assoc_type, + + // Library Features, + portable_simd, + integer_widen_truncate, )] #![cfg_attr(test, feature( // Language Features @@ -23,6 +27,7 @@ pub mod oauth_provider; pub mod property_type; pub mod user_deletion; +pub mod embedding; pub mod error; pub mod filter; pub mod migration; diff --git a/libs/@local/graph/type-fetcher/src/store.rs b/libs/@local/graph/type-fetcher/src/store.rs index 99a3ad3415c..d310af8f9a6 100644 --- a/libs/@local/graph/type-fetcher/src/store.rs +++ b/libs/@local/graph/type-fetcher/src/store.rs @@ -35,10 +35,10 @@ use hash_graph_store::{ UnarchiveDataTypeParams, UpdateDataTypeEmbeddingParams, UpdateDataTypesParams, }, entity::{ - CreateEntityParams, DeleteEntitiesParams, DeletionSummary, EntityStore, - EntityValidationReport, HasPermissionForEntitiesParams, PatchEntityParams, - QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, - QueryEntitySubgraphResponse, SearchEntitiesParams, SearchEntitiesResponse, + ClusterEntitiesParams, ClusterEntitiesResponse, CreateEntityParams, DeleteEntitiesParams, + DeletionSummary, EntityStore, EntityValidationReport, HasPermissionForEntitiesParams, + PatchEntityParams, QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, + QueryEntitySubgraphResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityParams, }, @@ -50,7 +50,9 @@ use hash_graph_store::{ QueryEntityTypesResponse, SearchEntityTypesParams, SearchEntityTypesResponse, UnarchiveEntityTypeParams, UpdateEntityTypeEmbeddingParams, UpdateEntityTypesParams, }, - error::{CheckPermissionError, DeletionError, InsertionError, QueryError, UpdateError}, + error::{ + CheckPermissionError, ClusterError, DeletionError, InsertionError, QueryError, UpdateError, + }, filter::{Filter, QueryRecord}, pool::StorePool, property_type::{ @@ -1713,6 +1715,14 @@ where self.store.update_entity_embeddings(actor_id, params).await } + async fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: ClusterEntitiesParams, + ) -> Result> { + self.store.cluster_entities(actor_id, params).await + } + async fn reindex_entity_cache(&mut self) -> Result<(), Report> { self.store.reindex_entity_cache().await } diff --git a/tests/graph/integration/postgres/lib.rs b/tests/graph/integration/postgres/lib.rs index 5c2ab899e3f..f9e54f423bb 100644 --- a/tests/graph/integration/postgres/lib.rs +++ b/tests/graph/integration/postgres/lib.rs @@ -890,6 +890,17 @@ impl EntityStore for DatabaseApi<'_> { self.store.reindex_entity_cache().await } + async fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: hash_graph_store::entity::ClusterEntitiesParams, + ) -> Result< + hash_graph_store::entity::ClusterEntitiesResponse, + Report, + > { + self.store.cluster_entities(actor_id, params).await + } + async fn has_permission_for_entities( &self, authenticated_actor: AuthenticatedActor, From a9bcac4d8d78bd3080f3a4de408af41723acf8ab Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Wed, 1 Jul 2026 10:07:07 +0200 Subject: [PATCH 02/10] fix: suggestions from code review --- Cargo.lock | 54 ++++++++++--------- Cargo.toml | 1 + libs/@local/graph/api/src/rest/entity.rs | 6 +-- .../store/postgres/knowledge/entity/mod.rs | 32 +++++------ libs/@local/graph/store/Cargo.toml | 3 ++ .../graph/store/src/embedding/dimension.rs | 6 +++ libs/@local/graph/store/src/entity/mod.rs | 3 +- libs/@local/graph/store/src/entity/store.rs | 7 +-- libs/@local/graph/store/src/error.rs | 6 +-- libs/@local/graph/type-fetcher/src/store.rs | 2 +- 10 files changed, 67 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9657b71a390..7c1ada76d6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -201,7 +201,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -225,7 +225,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -1848,7 +1848,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e75b2483e97a5a7da73ac68a05b629f9c53cff58d8ed1c77866079e18b00dba5" dependencies = [ "digest 0.10.7", - "spin 0.10.0", + "spin", ] [[package]] @@ -2213,7 +2213,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccc2776f0c61eca1ca32528f85548abd1a4be8fb53d1b21c013e4f18da1e7090" dependencies = [ "data-encoding", - "syn 2.0.118", + "syn 1.0.109", ] [[package]] @@ -2662,7 +2662,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -2684,7 +2684,7 @@ dependencies = [ "rustc_version", "serde", "serde_core", - "spin 0.12.1", + "spin", "supports-color", "supports-unicode", "thiserror 2.0.18", @@ -3796,6 +3796,9 @@ dependencies = [ "hash-temporal-client", "insta", "postgres-types", + "rand 0.10.1", + "rand_xoshiro", + "rayon", "serde", "serde_json", "simple-mermaid", @@ -4968,7 +4971,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -6298,7 +6301,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -8007,6 +8010,15 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand_xoshiro" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "662effc7698e08ea324d3acccf8d9d7f7bf79b9785e270a174ea36e56900c91d" +dependencies = [ + "rand_core 0.10.1", +] + [[package]] name = "rapidfuzz" version = "0.5.0" @@ -8494,7 +8506,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -8553,7 +8565,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9236,7 +9248,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9275,12 +9287,6 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" -[[package]] -name = "spin" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5231412d905519dca6a5deb0327d407be68d6c941feec004533401d3a0a715" - [[package]] name = "spki" version = "0.7.3" @@ -9371,7 +9377,7 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9659,10 +9665,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.4.3", + "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9728,7 +9734,7 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8c27177b12a6399ffc08b98f76f7c9a1f4fe9fc967c784c5a071fa8d93cf7e1" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9750,7 +9756,7 @@ dependencies = [ "parking_lot", "rustix", "signal-hook", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9760,7 +9766,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" dependencies = [ "rustix", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -11261,7 +11267,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 42b78fb5027..fdac17ab287 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -221,6 +221,7 @@ quote = { version = "1.0.41", default-features = fa rand = { version = "0.10.0", default-features = false } rand_core = { version = "0.10.0", default-features = false } rand_distr = { version = "0.6.0", default-features = false } +rand_xoshiro = { version = "0.8.1" } rapidfuzz = { version = "0.5.0", default-features = false } ratatui = { version = "0.30.0" } rayon = { version = "1.11.0", default-features = false } diff --git a/libs/@local/graph/api/src/rest/entity.rs b/libs/@local/graph/api/src/rest/entity.rs index c1210d5be46..1f08fc73634 100644 --- a/libs/@local/graph/api/src/rest/entity.rs +++ b/libs/@local/graph/api/src/rest/entity.rs @@ -17,9 +17,9 @@ use hash_graph_store::{ HasPermissionForEntitiesParams, LinkDataStateError, LinkDataValidationReport, LinkError, LinkTargetError, LinkValidationReport, LinkedEntityError, MetadataValidationReport, PatchEntityParams, PropertyMetadataValidationReport, QueryConversion, - QueryEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, - UnexpectedEntityType, UpdateEntityEmbeddingsParams, ValidateEntityComponents, - ValidateEntityParams, + QueryEntitiesResponse, SearchEntitiesFilter, SearchEntitiesResponse, + SummarizeEntitiesParams, SummarizeEntitiesResponse, UnexpectedEntityType, + UpdateEntityEmbeddingsParams, ValidateEntityComponents, ValidateEntityParams, }, entity_type::EntityTypeResolveDefinitions, pool::StorePool, diff --git a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs index 869bbc46ef8..adfc9383285 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs @@ -3,7 +3,7 @@ mod query; mod read; mod summary; -use alloc::borrow::Cow; +use alloc::{borrow::Cow, collections::BTreeMap}; use core::{borrow::Borrow as _, mem}; use std::collections::{HashMap, HashSet}; @@ -25,7 +25,8 @@ use hash_graph_store::{ EntityQueryPath, EntityQuerySorting, EntityStore, EntityTypeRetrieval, EntityTypesError, EntityValidationReport, EntityValidationType, HasPermissionForEntitiesParams, PatchEntityParams, QueryConversion, QueryEntitiesParams, QueryEntitiesResponse, - QueryEntitySubgraphParams, QueryEntitySubgraphResponse, SummarizeEntitiesParams, + QueryEntitySubgraphParams, QueryEntitySubgraphResponse, SearchEntitiesFilter, + SearchEntitiesParams, SearchEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityComponents, ValidateEntityParams, }, @@ -2554,36 +2555,31 @@ where Ok(permitted_ids) } - #[expect(clippy::too_many_lines)] + #[expect(clippy::too_many_lines, clippy::cast_possible_truncation)] #[tracing::instrument(skip(self, params))] async fn cluster_entities( &self, actor_id: ActorEntityUuid, params: ClusterEntitiesParams, ) -> Result> { - // 3072 fits in u16; compile-time verified. - const { - assert!(Embedding::DIM <= u16::MAX as usize); - } - #[expect( - clippy::cast_possible_truncation, - reason = "guarded by the const assertion above" - )] + const { assert!(Embedding::DIM <= u16::MAX as usize) }; const STORED_DIM: u16 = Embedding::DIM as u16; - let dim = Dimension::new(params.dimension).ok_or_else(|| { + let dimension = Dimension::new(params.dimension.get()).ok_or_else(|| { Report::new(ClusterError::InvalidDimension { dimension: params.dimension, }) + .attach(StatusCode::InvalidArgument) })?; - if dim.get() > STORED_DIM { + if dimension.get() > STORED_DIM { return Err(Report::new(ClusterError::DimensionTooLarge { - dimension: dim.get(), + dimension: dimension.value(), max: STORED_DIM, - })); + }) + .attach(StatusCode::InvalidArgument)); } - let truncated_dim = usize::from(dim.get()); + let truncated_dim = usize::from(dimension.get()); // Filter to entities the actor is allowed to view. let permitted = self @@ -2701,9 +2697,9 @@ where }), ); - let result = hash_graph_store::embedding::clustering::cluster(&flat, dim, &config); + let result = hash_graph_store::embedding::clustering::cluster(&flat, dimension, &config); - let mut groups: HashMap> = HashMap::new(); + let mut groups: BTreeMap> = BTreeMap::new(); for (index, id) in found_ids.iter().enumerate() { groups.entry(result.label(index)).or_default().push(*id); } diff --git a/libs/@local/graph/store/Cargo.toml b/libs/@local/graph/store/Cargo.toml index 825638c1e35..c02ac60c582 100644 --- a/libs/@local/graph/store/Cargo.toml +++ b/libs/@local/graph/store/Cargo.toml @@ -29,6 +29,9 @@ bytes = { workspace = true, optional = true } derive-where = { workspace = true } derive_more = { workspace = true, features = ["display", "error"] } futures = { workspace = true } +rand = { workspace = true } +rand_xoshiro = { workspace = true } +rayon = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } simple-mermaid = { workspace = true } diff --git a/libs/@local/graph/store/src/embedding/dimension.rs b/libs/@local/graph/store/src/embedding/dimension.rs index a3a1a31fe94..0ba85516ee3 100644 --- a/libs/@local/graph/store/src/embedding/dimension.rs +++ b/libs/@local/graph/store/src/embedding/dimension.rs @@ -31,6 +31,12 @@ impl Dimension { pub const fn get(self) -> u16 { self.0.get() } + + /// The raw dimension value as a [`NonZero`]. + #[must_use] + pub const fn value(self) -> NonZero { + self.0 + } } pub const D128: Dimension = Dimension(NonZero::new(128).unwrap()); diff --git a/libs/@local/graph/store/src/entity/mod.rs b/libs/@local/graph/store/src/entity/mod.rs index df563a365cc..e9aa8161500 100644 --- a/libs/@local/graph/store/src/entity/mod.rs +++ b/libs/@local/graph/store/src/entity/mod.rs @@ -9,7 +9,8 @@ pub use self::{ DiffEntityResult, EntityCluster, EntityPermissions, EntityStore, EntityValidationType, HasPermissionForEntitiesParams, LinkDeletionBehavior, PatchEntityParams, QueryConversion, QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, - QueryEntitySubgraphResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, + QueryEntitySubgraphResponse, SearchEntitiesFilter, SearchEntitiesParams, + SearchEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityComponents, ValidateEntityError, ValidateEntityParams, }, diff --git a/libs/@local/graph/store/src/entity/store.rs b/libs/@local/graph/store/src/entity/store.rs index 1af2a5af1c4..91bccfdf907 100644 --- a/libs/@local/graph/store/src/entity/store.rs +++ b/libs/@local/graph/store/src/entity/store.rs @@ -1,4 +1,5 @@ use alloc::borrow::Cow; +use core::num::NonZero; use std::collections::{HashMap, HashSet}; use error_stack::Report; @@ -538,7 +539,7 @@ pub struct ClusterEntitiesParams { /// Embedding dimension after matryoshka truncation. Must be a positive /// multiple of 8; values above 3072 are rejected. Defaults to 256. #[serde(default = "ClusterEntitiesParams::default_dimension")] - pub dimension: u16, + pub dimension: NonZero, /// Seed for the random number generator used in clustering. /// @@ -547,8 +548,8 @@ pub struct ClusterEntitiesParams { } impl ClusterEntitiesParams { - const fn default_dimension() -> u16 { - 256 + const fn default_dimension() -> NonZero { + const { NonZero::new(256).unwrap() } } } diff --git a/libs/@local/graph/store/src/error.rs b/libs/@local/graph/store/src/error.rs index fc40ef8daf4..3f1b684167e 100644 --- a/libs/@local/graph/store/src/error.rs +++ b/libs/@local/graph/store/src/error.rs @@ -1,4 +1,4 @@ -use core::{error::Error, fmt}; +use core::{error::Error, fmt, num::NonZero}; #[derive(Debug)] #[must_use] @@ -75,9 +75,9 @@ impl Error for CheckPermissionError {} #[must_use] pub enum ClusterError { #[display("dimension {dimension} is not a positive multiple of 8")] - InvalidDimension { dimension: u16 }, + InvalidDimension { dimension: NonZero }, #[display("dimension {dimension} exceeds stored embedding dimension {max}")] - DimensionTooLarge { dimension: u16, max: u16 }, + DimensionTooLarge { dimension: NonZero, max: u16 }, #[display("embedding query failed")] Store, } diff --git a/libs/@local/graph/type-fetcher/src/store.rs b/libs/@local/graph/type-fetcher/src/store.rs index d310af8f9a6..87a124b74b0 100644 --- a/libs/@local/graph/type-fetcher/src/store.rs +++ b/libs/@local/graph/type-fetcher/src/store.rs @@ -38,7 +38,7 @@ use hash_graph_store::{ ClusterEntitiesParams, ClusterEntitiesResponse, CreateEntityParams, DeleteEntitiesParams, DeletionSummary, EntityStore, EntityValidationReport, HasPermissionForEntitiesParams, PatchEntityParams, QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, - QueryEntitySubgraphResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, + QueryEntitySubgraphResponse, SearchEntitiesParams, SearchEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityParams, }, From 56d81ef44159767ab89de8089b56ce3d981d3da6 Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Wed, 1 Jul 2026 10:13:16 +0200 Subject: [PATCH 03/10] fix: spawn blocking for clustering --- .../src/store/postgres/knowledge/entity/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs index adfc9383285..31d970077fb 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs @@ -2697,7 +2697,11 @@ where }), ); - let result = hash_graph_store::embedding::clustering::cluster(&flat, dimension, &config); + let result = tokio::task::spawn_blocking(move || { + hash_graph_store::embedding::clustering::cluster(&flat, dimension, &config) + }) + .await + .change_context(ClusterError::Store)?; let mut groups: BTreeMap> = BTreeMap::new(); for (index, id) in found_ids.iter().enumerate() { From 2a2b1e432c2c69e11a7ae181fc9a587743f1acfe Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Wed, 1 Jul 2026 10:14:49 +0200 Subject: [PATCH 04/10] fix: regenerate --- libs/@local/graph/api/openapi/openapi.json | 135 +++++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/libs/@local/graph/api/openapi/openapi.json b/libs/@local/graph/api/openapi/openapi.json index 6fa9755b618..bd0485a85b8 100644 --- a/libs/@local/graph/api/openapi/openapi.json +++ b/libs/@local/graph/api/openapi/openapi.json @@ -1580,6 +1580,54 @@ } } }, + "/entities/embeddings/clusters": { + "post": { + "tags": [ + "Graph", + "Entity" + ], + "operationId": "cluster_entities", + "parameters": [ + { + "name": "X-Authenticated-User-Actor-Id", + "in": "header", + "description": "The ID of the actor which is used to authorize the request", + "required": true, + "schema": { + "$ref": "#/components/schemas/ActorEntityUuid" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClusterEntitiesParams" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Clusters of entities by embedding similarity", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClusterEntitiesResponse" + } + } + } + }, + "422": { + "description": "Provided request body is invalid" + }, + "500": { + "description": "Store error occurred" + } + } + } + }, "/entities/permissions": { "post": { "tags": [ @@ -3752,6 +3800,62 @@ "propertyName": "kind" } }, + "ClusterEntitiesParams": { + "type": "object", + "required": [ + "entityIds", + "clusterCount" + ], + "properties": { + "clusterCount": { + "type": "integer", + "format": "int32", + "description": "Desired number of clusters. Clamped to the number of entities with\nembeddings when that is smaller.", + "minimum": 0 + }, + "dimension": { + "$ref": "#/components/schemas/NonZero" + }, + "entityIds": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityId" + } + }, + "seed": { + "type": "integer", + "format": "int64", + "description": "Seed for the random number generator used in clustering.\n\nIf not provided, a random seed will be used.", + "nullable": true, + "minimum": 0 + } + }, + "additionalProperties": false + }, + "ClusterEntitiesResponse": { + "type": "object", + "description": "Result of [`EntityStore::cluster_entities`].", + "required": [ + "clusters", + "missingEmbeddings" + ], + "properties": { + "clusters": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityCluster" + }, + "description": "One entry per non-empty cluster. Empty clusters (no points assigned)\nare omitted." + }, + "missingEmbeddings": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityId" + }, + "description": "Entities from the request that had no stored embedding." + } + } + }, "CommonQueryEntityTypesParams": { "type": "object", "required": [ @@ -4517,6 +4621,37 @@ } } }, + "EntityCluster": { + "type": "object", + "description": "One cluster from a spherical k-means run over entity embeddings.", + "required": [ + "clusterId", + "entityIds", + "centroid" + ], + "properties": { + "centroid": { + "type": "array", + "items": { + "type": "number", + "format": "float" + }, + "description": "Unit-normalized centroid with length equal to the requested dimension." + }, + "clusterId": { + "type": "integer", + "format": "int32", + "description": "Index in `0..cluster_count`.", + "minimum": 0 + }, + "entityIds": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityId" + } + } + } + }, "EntityDeletionProvenance": { "type": "object", "required": [ From 24eb8f364d8703304efbc48f12a2f78cb40559ae Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Wed, 1 Jul 2026 10:22:47 +0200 Subject: [PATCH 05/10] fix: unnest sql query --- .../postgres-store/src/store/postgres/knowledge/entity/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs index 31d970077fb..c863e5d9ac0 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs @@ -2621,8 +2621,8 @@ where embedding FROM entity_embeddings e WHERE e.property IS NULL - AND (e.web_id, e.entity_uuid) IN (SELECT unnest($1::uuid[]), \ - unnest($2::uuid[]))" + AND (e.web_id, e.entity_uuid) IN (SELECT * FROM unnest($1::uuid[], \ + $2::uuid[]))" ), [ &web_ids as &(dyn ToSql + Sync), From 5db85e2f79c70e8885c626f6263781fd49611390 Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Wed, 1 Jul 2026 10:31:18 +0200 Subject: [PATCH 06/10] fix: lints --- libs/@local/graph/store/src/embedding/clustering.rs | 11 ----------- libs/@local/graph/store/src/embedding/kernel.rs | 11 ++++++----- libs/@local/graph/store/src/embedding/mod.rs | 6 ++---- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/libs/@local/graph/store/src/embedding/clustering.rs b/libs/@local/graph/store/src/embedding/clustering.rs index f48da06398a..45839919a68 100644 --- a/libs/@local/graph/store/src/embedding/clustering.rs +++ b/libs/@local/graph/store/src/embedding/clustering.rs @@ -37,12 +37,6 @@ pub struct Config { } impl Config { - /// Creates a configuration for `k` clusters, drawing the seed from `rng`. - #[must_use] - pub(crate) fn for_k(k: u16, mut rng: impl Rng) -> Self { - Self::for_k_with_seed(k, rng.random()) - } - /// Creates a configuration for `k` clusters with a fixed seed. /// /// Defaults: 30 max iterations, 5 restarts, 1e-4 convergence tolerance, @@ -112,11 +106,6 @@ impl Clustering { pub fn label(&self, entity: usize) -> u16 { self.labels[entity] } - - /// Returns a mutable reference to the cluster label for point `entity`. - fn label_mut(&mut self, entity: usize) -> &mut u16 { - &mut self.labels[entity] - } } // TODO: I wonder if we can make this allocation less diff --git a/libs/@local/graph/store/src/embedding/kernel.rs b/libs/@local/graph/store/src/embedding/kernel.rs index 0090f63c4c5..7abaf1c8b5b 100644 --- a/libs/@local/graph/store/src/embedding/kernel.rs +++ b/libs/@local/graph/store/src/embedding/kernel.rs @@ -1,3 +1,8 @@ +#![expect( + clippy::inline_always, + reason = "while usually discouraged, SIMD operations need to be inlined, as otherwise we \ + spill SIMD registers, see the SIMD documentation." +)] use core::simd::{Simd, f32x8, num::SimdFloat as _}; /// Fused multiply-add when the target has native FMA, separate mul+add otherwise. @@ -241,11 +246,7 @@ pub(crate) unsafe fn normalize(value: &mut [f32]) { /// # Safety /// * all six slices have length `d` /// * `d` is a multiple of 8 -#[expect( - clippy::inline_always, - reason = "micro-kernel must inline into nearest4 to keep accumulators in registers" -)] -#[inline(always)] +#[inline(always)] // micro-kernel must inline nearest4 to keep accumulators in registers pub(crate) unsafe fn micro_4x2( p0: &[f32], p1: &[f32], diff --git a/libs/@local/graph/store/src/embedding/mod.rs b/libs/@local/graph/store/src/embedding/mod.rs index cefb998fb67..d008b247ce0 100644 --- a/libs/@local/graph/store/src/embedding/mod.rs +++ b/libs/@local/graph/store/src/embedding/mod.rs @@ -1,13 +1,11 @@ #![expect( unsafe_code, - dead_code, clippy::indexing_slicing, clippy::float_arithmetic, clippy::min_ident_chars, clippy::many_single_char_names, - reason = "embedding module is under active development; dead_code is expected until the \ - public API is wired up. Single-char idents (k, n, m, d, x) are standard \ - mathematical notation for clustering." + reason = "embedding module is under active development. Single-char idents (k, n, m, d, x) \ + are standard mathematical notation for clustering." )] pub mod clustering; From 605c9dd20df7bb8490df6de987f3919a34060311 Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Wed, 1 Jul 2026 11:02:03 +0200 Subject: [PATCH 07/10] fix: lints --- libs/@local/graph/store/src/embedding/kernel.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/@local/graph/store/src/embedding/kernel.rs b/libs/@local/graph/store/src/embedding/kernel.rs index 7abaf1c8b5b..d5610052694 100644 --- a/libs/@local/graph/store/src/embedding/kernel.rs +++ b/libs/@local/graph/store/src/embedding/kernel.rs @@ -7,8 +7,8 @@ use core::simd::{Simd, f32x8, num::SimdFloat as _}; /// Fused multiply-add when the target has native FMA, separate mul+add otherwise. /// -/// On aarch64, FMA is part of the base NEON instruction set (`fmla`). -/// On x86_64, FMA requires the `fma` target feature (`vfmadd`); without it, +/// On `aarch64`, FMA is part of the base NEON instruction set (`fmla`). +/// On `x86_64`, FMA requires the `fma` target feature (`vfmadd`); without it, /// `StdFloat::mul_add` falls back to a per-lane `fmaf` libc call which /// destroys throughput. The non-FMA path uses a plain multiply and add /// (`vmulps` + `vaddps`) instead. From 2953664655fc018ae73dd365b153d489b46168ec Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Wed, 1 Jul 2026 11:06:51 +0200 Subject: [PATCH 08/10] fix: openapi schema --- libs/@local/graph/api/openapi/openapi.json | 7 ++++++- libs/@local/graph/store/src/entity/store.rs | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/libs/@local/graph/api/openapi/openapi.json b/libs/@local/graph/api/openapi/openapi.json index bd0485a85b8..52fb63f2f02 100644 --- a/libs/@local/graph/api/openapi/openapi.json +++ b/libs/@local/graph/api/openapi/openapi.json @@ -3814,7 +3814,12 @@ "minimum": 0 }, "dimension": { - "$ref": "#/components/schemas/NonZero" + "type": "integer", + "format": "int32", + "description": "Embedding dimension after matryoshka truncation. Must be a positive\nmultiple of 8; values above 3072 are rejected. Defaults to 256.", + "default": 256, + "example": 256, + "minimum": 1 }, "entityIds": { "type": "array", diff --git a/libs/@local/graph/store/src/entity/store.rs b/libs/@local/graph/store/src/entity/store.rs index 91bccfdf907..e6c43cdea12 100644 --- a/libs/@local/graph/store/src/entity/store.rs +++ b/libs/@local/graph/store/src/entity/store.rs @@ -539,6 +539,7 @@ pub struct ClusterEntitiesParams { /// Embedding dimension after matryoshka truncation. Must be a positive /// multiple of 8; values above 3072 are rejected. Defaults to 256. #[serde(default = "ClusterEntitiesParams::default_dimension")] + #[cfg_attr(feature = "utoipa", schema(value_type = u16, minimum = 1, default = 256, example = 256))] pub dimension: NonZero, /// Seed for the random number generator used in clustering. From 9fa07e159317b3891c4be077e21cf484f2ae35cf Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Fri, 3 Jul 2026 13:09:22 +0200 Subject: [PATCH 09/10] fix: docs --- libs/@local/graph/store/src/embedding/clustering.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/@local/graph/store/src/embedding/clustering.rs b/libs/@local/graph/store/src/embedding/clustering.rs index 45839919a68..cec24ed18ab 100644 --- a/libs/@local/graph/store/src/embedding/clustering.rs +++ b/libs/@local/graph/store/src/embedding/clustering.rs @@ -9,8 +9,8 @@ use super::{dimension::Dimension, kernel}; /// Parameters for k-means clustering. /// -/// Use [`Config::for_k`] or [`Config::for_k_with_seed`] to construct with -/// reasonable defaults, then override individual fields as needed. +/// Use [`Config::for_k_with_seed`] to construct with reasonable defaults, then override individual +/// fields as needed. pub struct Config { /// Number of clusters. pub k: u16, From 3287384deeb1a9dc389290004ef378b05530de0f Mon Sep 17 00:00:00 2001 From: Bilal Mahmoud <7252775+indietyp@users.noreply.github.com> Date: Fri, 3 Jul 2026 13:11:31 +0200 Subject: [PATCH 10/10] fix: docs --- libs/@local/graph/store/src/embedding/kernel.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/libs/@local/graph/store/src/embedding/kernel.rs b/libs/@local/graph/store/src/embedding/kernel.rs index d5610052694..1b971bc23b1 100644 --- a/libs/@local/graph/store/src/embedding/kernel.rs +++ b/libs/@local/graph/store/src/embedding/kernel.rs @@ -320,9 +320,8 @@ pub(crate) unsafe fn micro_4x2( /// [`micro_4x2`] tiled kernel. /// /// Returns `(centroid_index, raw_dot_product)` for each of the 4 points. -/// The raw dot product is **not** a distance; the caller must convert via -/// [`squared_chord_distance`](super::clustering::squared_chord_distance) -/// if needed. +/// The raw dot product is **not** a distance; and must be converted via +/// using the chord distance formula. /// /// # Safety ///