diff --git a/Cargo.lock b/Cargo.lock index 9657b71a390..af6e6862631 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -201,7 +201,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -225,7 +225,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -1848,7 +1848,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e75b2483e97a5a7da73ac68a05b629f9c53cff58d8ed1c77866079e18b00dba5" dependencies = [ "digest 0.10.7", - "spin 0.10.0", + "spin", ] [[package]] @@ -2213,7 +2213,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccc2776f0c61eca1ca32528f85548abd1a4be8fb53d1b21c013e4f18da1e7090" dependencies = [ "data-encoding", - "syn 2.0.118", + "syn 1.0.109", ] [[package]] @@ -2662,7 +2662,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -2684,7 +2684,7 @@ dependencies = [ "rustc_version", "serde", "serde_core", - "spin 0.12.1", + "spin", "supports-color", "supports-unicode", "thiserror 2.0.18", @@ -3784,6 +3784,8 @@ name = "hash-graph-store" version = "0.0.0" dependencies = [ "bytes", + "codspeed-criterion-compat", + "darwin-kperf-criterion", "derive-where", "derive_more", "error-stack", @@ -3796,6 +3798,9 @@ dependencies = [ "hash-temporal-client", "insta", "postgres-types", + "rand 0.10.1", + "rand_xoshiro", + "rayon", "serde", "serde_json", "simple-mermaid", @@ -4968,7 +4973,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -6298,7 +6303,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -8007,6 +8012,15 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand_xoshiro" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "662effc7698e08ea324d3acccf8d9d7f7bf79b9785e270a174ea36e56900c91d" +dependencies = [ + "rand_core 0.10.1", +] + [[package]] name = "rapidfuzz" version = "0.5.0" @@ -8494,7 +8508,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -8553,7 +8567,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9236,7 +9250,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9275,12 +9289,6 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" -[[package]] -name = "spin" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5231412d905519dca6a5deb0327d407be68d6c941feec004533401d3a0a715" - [[package]] name = "spki" version = "0.7.3" @@ -9371,7 +9379,7 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9659,10 +9667,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.4.3", + "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9728,7 +9736,7 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8c27177b12a6399ffc08b98f76f7c9a1f4fe9fc967c784c5a071fa8d93cf7e1" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9750,7 +9758,7 @@ dependencies = [ "parking_lot", "rustix", "signal-hook", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -9760,7 +9768,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" dependencies = [ "rustix", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -11261,7 +11269,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 42b78fb5027..fdac17ab287 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -221,6 +221,7 @@ quote = { version = "1.0.41", default-features = fa rand = { version = "0.10.0", default-features = false } rand_core = { version = "0.10.0", default-features = false } rand_distr = { version = "0.6.0", default-features = false } +rand_xoshiro = { version = "0.8.1" } rapidfuzz = { version = "0.5.0", default-features = false } ratatui = { version = "0.30.0" } rayon = { version = "1.11.0", default-features = false } diff --git a/libs/@local/graph/api/openapi/openapi.json b/libs/@local/graph/api/openapi/openapi.json index 6fa9755b618..affb8fa9222 100644 --- a/libs/@local/graph/api/openapi/openapi.json +++ b/libs/@local/graph/api/openapi/openapi.json @@ -1580,6 +1580,54 @@ } } }, + "/entities/embeddings/clusters": { + "post": { + "tags": [ + "Graph", + "Entity" + ], + "operationId": "cluster_entities", + "parameters": [ + { + "name": "X-Authenticated-User-Actor-Id", + "in": "header", + "description": "The ID of the actor which is used to authorize the request", + "required": true, + "schema": { + "$ref": "#/components/schemas/ActorEntityUuid" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClusterEntitiesParams" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Clusters of entities by embedding similarity", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClusterEntitiesResponse" + } + } + } + }, + "422": { + "description": "Provided request body is invalid" + }, + "500": { + "description": "Store error occurred" + } + } + } + }, "/entities/permissions": { "post": { "tags": [ @@ -3752,6 +3800,73 @@ "propertyName": "kind" } }, + "ClusterEntitiesParams": { + "type": "object", + "required": [ + "entityIds", + "clusterCount" + ], + "properties": { + "clusterCount": { + "type": "integer", + "format": "int32", + "description": "Desired number of clusters. Clamped to the number of entities with\nembeddings when that is smaller.", + "minimum": 0 + }, + "dimension": { + "type": "integer", + "format": "int32", + "description": "Embedding dimension after matryoshka truncation. Must be a positive\nmultiple of 8; values above 3072 are rejected. Defaults to 256.", + "default": 256, + "example": 256, + "minimum": 1 + }, + "entityIds": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityId" + } + }, + "seed": { + "type": "integer", + "format": "int64", + "description": "Seed for the random number generator used in clustering.\n\nIf not provided, a random seed will be used.", + "nullable": true, + "minimum": 0 + } + }, + "additionalProperties": false + }, + "ClusterEntitiesResponse": { + "type": "object", + "description": "Result of [`EntityStore::cluster_entities`].", + "required": [ + "clusters", + "missingEmbeddings", + "inertia" + ], + "properties": { + "clusters": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityCluster" + }, + "description": "One entry per non-empty cluster. Empty clusters (no points assigned)\nare omitted." + }, + "inertia": { + "type": "number", + "format": "float", + "description": "Sum of squared chord distances from every clustered entity to its\nassigned centroid. Lower is tighter; comparable across runs over the\nsame entities, e.g. to choose a cluster count. `0.0` when nothing was\nclustered." + }, + "missingEmbeddings": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityId" + }, + "description": "Entities from the request that had no stored embedding." + } + } + }, "CommonQueryEntityTypesParams": { "type": "object", "required": [ @@ -4517,6 +4632,37 @@ } } }, + "EntityCluster": { + "type": "object", + "description": "One cluster from a spherical k-means run over entity embeddings.", + "required": [ + "clusterId", + "entityIds", + "centroid" + ], + "properties": { + "centroid": { + "type": "array", + "items": { + "type": "number", + "format": "float" + }, + "description": "Unit-normalized centroid with length equal to the requested dimension." + }, + "clusterId": { + "type": "integer", + "format": "int32", + "description": "Index in `0..cluster_count`.", + "minimum": 0 + }, + "entityIds": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityId" + } + } + } + }, "EntityDeletionProvenance": { "type": "object", "required": [ diff --git a/libs/@local/graph/api/src/rest/entity.rs b/libs/@local/graph/api/src/rest/entity.rs index dcc4846f143..1f08fc73634 100644 --- a/libs/@local/graph/api/src/rest/entity.rs +++ b/libs/@local/graph/api/src/rest/entity.rs @@ -10,16 +10,16 @@ use hash_graph_postgres_store::store::error::{EntityDoesNotExist, RaceConditionO use hash_graph_store::{ self, entity::{ - ClosedMultiEntityTypeMap, CreateEntityParams, DiffEntityParams, DiffEntityResult, - EntityPermissions, EntityQueryCursor, EntityQuerySortingRecord, EntityQuerySortingToken, - EntityQueryToken, EntityStore, EntityTypesError, EntityValidationReport, - EntityValidationType, HasPermissionForEntitiesParams, LinkDataStateError, - LinkDataValidationReport, LinkError, LinkTargetError, LinkValidationReport, - LinkedEntityError, MetadataValidationReport, PatchEntityParams, - PropertyMetadataValidationReport, QueryConversion, QueryEntitiesResponse, - SearchEntitiesFilter, SearchEntitiesResponse, SummarizeEntitiesParams, - SummarizeEntitiesResponse, UnexpectedEntityType, UpdateEntityEmbeddingsParams, - ValidateEntityComponents, ValidateEntityParams, + ClosedMultiEntityTypeMap, ClusterEntitiesParams, ClusterEntitiesResponse, + CreateEntityParams, DiffEntityParams, DiffEntityResult, EntityCluster, EntityPermissions, + EntityQueryCursor, EntityQuerySortingRecord, EntityQuerySortingToken, EntityQueryToken, + EntityStore, EntityTypesError, EntityValidationReport, EntityValidationType, + HasPermissionForEntitiesParams, LinkDataStateError, LinkDataValidationReport, LinkError, + LinkTargetError, LinkValidationReport, LinkedEntityError, MetadataValidationReport, + PatchEntityParams, PropertyMetadataValidationReport, QueryConversion, + QueryEntitiesResponse, SearchEntitiesFilter, SearchEntitiesResponse, + SummarizeEntitiesParams, SummarizeEntitiesResponse, UnexpectedEntityType, + UpdateEntityEmbeddingsParams, ValidateEntityComponents, ValidateEntityParams, }, entity_type::EntityTypeResolveDefinitions, pool::StorePool, @@ -98,6 +98,7 @@ use crate::rest::{ summarize_entities, patch_entity, update_entity_embeddings, + cluster_entities, diff_entity, ), components( @@ -115,6 +116,9 @@ use crate::rest::{ Embedding, UpdateEntityEmbeddingsParams, EntityEmbedding, + ClusterEntitiesParams, + ClusterEntitiesResponse, + EntityCluster, EntityQueryToken, PatchEntityParams, @@ -228,7 +232,12 @@ impl EntityResource { .route("/bulk", post(create_entities::)) .route("/diff", post(diff_entity::)) .route("/validate", post(validate_entity::)) - .route("/embeddings", post(update_entity_embeddings::)) + .nest( + "/embeddings", + Router::new() + .route("/", post(update_entity_embeddings::)) + .route("/clusters", post(cluster_entities::)), + ) .route("/permissions", post(has_permission_for_entities::)) .route("/search", post(search_entities::)) .nest( @@ -780,6 +789,42 @@ where .map_err(report_to_response) } +#[utoipa::path( + post, + path = "/entities/embeddings/clusters", + tag = "Entity", + params( + ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), + ), + responses( + (status = 200, content_type = "application/json", description = "Clusters of entities by embedding similarity", body = ClusterEntitiesResponse), + (status = 422, content_type = "text/plain", description = "Provided request body is invalid"), + + (status = 500, description = "Store error occurred"), + ), + request_body = ClusterEntitiesParams, +)] +async fn cluster_entities( + AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, + store_pool: Extension>, + temporal_client: Extension>>, + Json(params): Json, +) -> Result, BoxedResponse> +where + S: StorePool + Send + Sync, +{ + let store = store_pool + .acquire(temporal_client.0) + .await + .map_err(report_to_response)?; + + store + .cluster_entities(actor_id, params) + .await + .map_err(report_to_response) + .map(Json) +} + #[utoipa::path( post, path = "/entities/diff", diff --git a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs index 93e56376c68..23d760c76e2 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/knowledge/entity/mod.rs @@ -2,7 +2,8 @@ mod delete; mod query; mod read; mod summary; -use alloc::borrow::Cow; + +use alloc::{borrow::Cow, collections::BTreeMap}; use core::{borrow::Borrow as _, mem}; use std::collections::{HashMap, HashSet}; @@ -17,18 +18,22 @@ use hash_graph_authorization::policies::{ store::{PolicyCreationParams, PrincipalStore as _}, }; use hash_graph_store::{ + embedding::dimension::Dimension, entity::{ - CreateEntityParams, DeleteEntitiesParams, DeletionSummary, EmptyEntityTypes, - EntityPermissions, EntityQueryCursor, EntityQueryPath, EntityQuerySorting, EntityStore, - EntityTypeRetrieval, EntityTypesError, EntityValidationReport, EntityValidationType, - HasPermissionForEntitiesParams, PatchEntityParams, QueryConversion, QueryEntitiesParams, - QueryEntitiesResponse, QueryEntitySubgraphParams, QueryEntitySubgraphResponse, - SearchEntitiesFilter, SearchEntitiesParams, SearchEntitiesResponse, - SummarizeEntitiesParams, SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, - ValidateEntityComponents, ValidateEntityParams, + ClusterEntitiesParams, ClusterEntitiesResponse, CreateEntityParams, DeleteEntitiesParams, + DeletionSummary, EmptyEntityTypes, EntityCluster, EntityPermissions, EntityQueryCursor, + EntityQueryPath, EntityQuerySorting, EntityStore, EntityTypeRetrieval, EntityTypesError, + EntityValidationReport, EntityValidationType, HasPermissionForEntitiesParams, + PatchEntityParams, QueryConversion, QueryEntitiesParams, QueryEntitiesResponse, + QueryEntitySubgraphParams, QueryEntitySubgraphResponse, SearchEntitiesFilter, + SearchEntitiesParams, SearchEntitiesResponse, SummarizeEntitiesParams, + SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityComponents, + ValidateEntityParams, }, entity_type::{EntityTypeStore as _, IncludeEntityTypeOption}, - error::{CheckPermissionError, DeletionError, InsertionError, QueryError, UpdateError}, + error::{ + CheckPermissionError, ClusterError, DeletionError, InsertionError, QueryError, UpdateError, + }, filter::{ Filter, FilterExpression, FilterExpressionList, Parameter, ParameterList, protection::transform_filter, @@ -2549,6 +2554,180 @@ where Ok(permitted_ids) } + + #[expect(clippy::too_many_lines, clippy::cast_possible_truncation)] + #[tracing::instrument(skip(self, params))] + async fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: ClusterEntitiesParams, + ) -> Result> { + const { assert!(Embedding::DIM <= u16::MAX as usize) }; + const STORED_DIM: u16 = Embedding::DIM as u16; + + let dimension = Dimension::new(params.dimension.get()).ok_or_else(|| { + Report::new(ClusterError::InvalidDimension { + dimension: params.dimension, + }) + .attach(StatusCode::InvalidArgument) + })?; + + if dimension.get() > STORED_DIM { + return Err(Report::new(ClusterError::DimensionTooLarge { + dimension: dimension.value(), + max: STORED_DIM, + }) + .attach(StatusCode::InvalidArgument)); + } + let truncated_dim = usize::from(dimension.get()); + + // Filter to entities the actor is allowed to view. + let permitted = self + .has_permission_for_entities( + AuthenticatedActor::from(actor_id), + HasPermissionForEntitiesParams { + action: ActionName::ViewEntity, + entity_ids: Cow::Borrowed(¶ms.entity_ids), + temporal_axes: QueryTemporalAxesUnresolved::TransactionTime { + pinned: PinnedTemporalAxisUnresolved::new(None), + variable: VariableTemporalAxisUnresolved::new(None, None), + }, + include_drafts: false, + }, + ) + .await + .change_context(ClusterError::Store)?; + + let permitted_ids: Vec = params + .entity_ids + .iter() + .filter(|&id| permitted.contains_key(id)) + .copied() + .collect(); + + let entity_uuids: Vec = permitted_ids.iter().map(|id| id.entity_uuid).collect(); + let web_ids: Vec = permitted_ids.iter().map(|id| id.web_id).collect(); + + // Truncate server-side via `subvector` so postgres only sends + // `truncated_dim`-dimensional vectors over the wire. + // + // Matryoshka truncation shortens the vectors without re-normalizing; + // that is fine here because spherical k-means normalizes internally + // (it works with inverse norms), so no `l2_normalize` is needed. + let row_stream = self + .as_client() + .query_raw( + &format!( + "SELECT + e.web_id, + e.entity_uuid, + subvector(e.embedding, 1, {truncated_dim})::vector({truncated_dim}) AS \ + embedding + FROM entity_embeddings e + WHERE e.property IS NULL + AND (e.web_id, e.entity_uuid) IN (SELECT * FROM unnest($1::uuid[], \ + $2::uuid[]))" + ), + [ + &web_ids as &(dyn ToSql + Sync), + &entity_uuids as &(dyn ToSql + Sync), + ], + ) + .instrument(tracing::info_span!( + "cluster_entities.embeddings", + otel.kind = "client", + db.system = "postgresql", + peer.service = "Postgres", + )) + .await + .change_context(ClusterError::Store)?; + + let mut row_stream = core::pin::pin!(row_stream); + + let mut flat: Vec = Vec::with_capacity(permitted_ids.len() * truncated_dim); + let mut found_ids: Vec = Vec::with_capacity(permitted_ids.len()); + + while let Some(row) = row_stream + .try_next() + .await + .change_context(ClusterError::Store)? + { + let web_id: WebId = row.get(0); + let entity_uuid: EntityUuid = row.get(1); + let embedding: Embedding<'_> = row.get(2); + + flat.extend(embedding.iter()); + + found_ids.push(EntityId { + web_id, + entity_uuid, + draft_id: None, + }); + } + + // Every requested entity not in a cluster goes into + // `missing_embeddings`, whether due to permissions or no embedding. + // Distinguishing the two would leak permission information. + let found_set: HashSet<(WebId, EntityUuid)> = found_ids + .iter() + .map(|id| (id.web_id, id.entity_uuid)) + .collect(); + let missing_embeddings: Vec = params + .entity_ids + .into_iter() + .filter(|id| !found_set.contains(&(id.web_id, id.entity_uuid))) + .collect(); + + if found_ids.is_empty() || params.cluster_count == 0 { + return Ok(ClusterEntitiesResponse { + clusters: Vec::new(), + missing_embeddings, + inertia: 0.0, + }); + } + + let config = hash_graph_store::embedding::clustering::Config::for_k_with_seed( + params.cluster_count, + params.seed.unwrap_or_else(|| { + std::time::SystemTime::UNIX_EPOCH + .elapsed() + .map_or(0, |elapsed| { + #[expect( + clippy::cast_possible_truncation, + reason = "seed only needs entropy, truncation is fine" + )] + let seed = elapsed.as_nanos() as u64; + seed + }) + }), + ); + + let result = tokio::task::spawn_blocking(move || { + hash_graph_store::embedding::clustering::cluster(&flat, dimension, &config) + }) + .await + .change_context(ClusterError::Store)?; + + let mut groups: BTreeMap> = BTreeMap::new(); + for (index, id) in found_ids.iter().enumerate() { + groups.entry(result.label(index)).or_default().push(*id); + } + + let clusters = groups + .into_iter() + .map(|(cluster_id, entity_ids)| EntityCluster { + cluster_id, + entity_ids, + centroid: result.centroid(cluster_id).to_vec(), + }) + .collect(); + + Ok(ClusterEntitiesResponse { + clusters, + missing_embeddings, + inertia: result.inertia, + }) + } } #[derive(Debug)] diff --git a/libs/@local/graph/store/Cargo.toml b/libs/@local/graph/store/Cargo.toml index 825638c1e35..afd99bc632c 100644 --- a/libs/@local/graph/store/Cargo.toml +++ b/libs/@local/graph/store/Cargo.toml @@ -29,6 +29,9 @@ bytes = { workspace = true, optional = true } derive-where = { workspace = true } derive_more = { workspace = true, features = ["display", "error"] } futures = { workspace = true } +rand = { workspace = true } +rand_xoshiro = { workspace = true } +rayon = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } simple-mermaid = { workspace = true } @@ -37,14 +40,20 @@ tracing = { workspace = true } uuid = { workspace = true, features = ["v4"] } [dev-dependencies] -hash-codegen = { workspace = true } -insta = { workspace = true } -tokio = { workspace = true, features = ["macros"] } +codspeed-criterion-compat = { workspace = true } +darwin-kperf-criterion = { workspace = true, features = ["codspeed"] } +hash-codegen = { workspace = true } +insta = { workspace = true } +tokio = { workspace = true, features = ["macros"] } [[test]] name = "codegen" required-features = ["codegen"] +[[bench]] +name = "embedding" +harness = false + [features] codegen = ["dep:specta", "type-system/codegen", "hash-graph-authorization/codegen"] utoipa = ["hash-graph-temporal-versioning/utoipa", "type-system/utoipa", "dep:utoipa"] diff --git a/libs/@local/graph/store/benches/embedding.rs b/libs/@local/graph/store/benches/embedding.rs new file mode 100644 index 00000000000..13094b45589 --- /dev/null +++ b/libs/@local/graph/store/benches/embedding.rs @@ -0,0 +1,237 @@ +//! Benchmarks for the embedding k-means module. +//! +//! Two groups: +//! +//! * `embedding/kernel/*` — single-threaded SIMD micro-kernels, measured in retired instructions +//! via Apple PMCs (near-deterministic; requires root on macOS) with an automatic wall-clock +//! fallback on other platforms. +//! * `embedding/cluster/*` — end-to-end [`cluster`] runs. Always wall-clock, because the work is +//! spread across the rayon pool and per-thread instruction counts would only see the calling +//! thread. +//! +//! [`cluster`]: hash_graph_store::embedding::clustering::cluster +#![expect( + unsafe_code, + clippy::float_arithmetic, + clippy::indexing_slicing, + clippy::integer_division, + clippy::integer_division_remainder_used, + clippy::min_ident_chars, + clippy::significant_drop_tightening, + reason = "benchmarks exercise the unsafe SIMD kernels directly and build float test data; \ + single-char idents (k, n, d) are standard mathematical notation for clustering; the \ + drop-tightening warning originates inside `criterion_group!`" +)] + +use core::hint::black_box; + +use codspeed_criterion_compat::{ + BenchmarkId, Criterion, criterion_group, criterion_main, measurement::Measurement, +}; +use hash_graph_store::embedding::{ + clustering::{Config, cluster}, + dimension::Dimension, + kernel, +}; +use rand::{RngExt as _, SeedableRng as _}; +use rand_xoshiro::Xoshiro256PlusPlus; + +/// Uniform random values in `[-1, 1)`. +fn random_vec(len: usize, seed: u64) -> Vec { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + core::iter::repeat_with(|| rng.random_range(-1.0..1.0)) + .take(len) + .collect() +} + +/// Uniform random values in `[0.1, 1)`, guaranteed positive so repeated +/// accumulation saturates at infinity instead of producing NaNs. +fn random_positive_vec(len: usize, seed: u64) -> Vec { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + core::iter::repeat_with(|| rng.random_range(0.1..1.0)) + .take(len) + .collect() +} + +/// Well-separated blobs: `k` clusters of `points_per_cluster` points in +/// `d`-dimensional space, each with a dominant axis. Mirrors the shape of +/// real embedding workloads better than uniform noise: the fit converges +/// instead of always exhausting `max_iters`. +fn blobs(points_per_cluster: usize, k: usize, d: usize, seed: u64) -> Vec { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + let mut data = vec![0.0_f32; points_per_cluster * k * d]; + + for (index, row) in data.chunks_exact_mut(d).enumerate() { + let axis = (index / points_per_cluster) % d; + row[axis] = 10.0; + for value in row.iter_mut() { + *value += rng.random_range(-0.01..0.01); + } + } + + data +} + +const KERNEL_DIMS: &[usize] = &[256, 1536, 3072]; + +fn bench_dot(criterion: &mut Criterion) { + let mut group = criterion.benchmark_group("embedding/kernel/dot"); + + for &d in KERNEL_DIMS { + let lhs = random_vec(d, 1); + let rhs = random_vec(d, 2); + + group.bench_with_input(BenchmarkId::from_parameter(d), &d, |bencher, _| { + // SAFETY: both slices have length `d`, a multiple of 8. + bencher.iter(|| unsafe { kernel::dot(black_box(&lhs), black_box(&rhs)) }); + }); + } + + group.finish(); +} + +fn bench_add_scaled_into(criterion: &mut Criterion) { + let mut group = criterion.benchmark_group("embedding/kernel/add_scaled_into"); + + for &d in KERNEL_DIMS { + let src = random_positive_vec(d, 3); + let mut dst = random_positive_vec(d, 4); + + group.bench_with_input(BenchmarkId::from_parameter(d), &d, |bencher, _| { + // SAFETY: both slices have length `d`, a multiple of 8. + bencher.iter(|| unsafe { + kernel::add_scaled_into(black_box(&mut dst), black_box(&src), black_box(0.5)); + }); + }); + } + + group.finish(); +} + +fn bench_micro_4x2(criterion: &mut Criterion) { + let mut group = criterion.benchmark_group("embedding/kernel/micro_4x2"); + + for &d in KERNEL_DIMS { + let points: Vec> = (0..4).map(|seed| random_vec(d, 10 + seed)).collect(); + let c0 = random_vec(d, 20); + let c1 = random_vec(d, 21); + + group.bench_with_input(BenchmarkId::from_parameter(d), &d, |bencher, _| { + // SAFETY: all six slices have length `d`, a multiple of 8. + bencher.iter(|| unsafe { + kernel::micro_4x2( + black_box(&points[0]), + black_box(&points[1]), + black_box(&points[2]), + black_box(&points[3]), + black_box(&c0), + black_box(&c1), + ) + }); + }); + } + + group.finish(); +} + +macro_rules! nz { + ($expr:expr) => { + const { ::core::num::NonZero::new($expr).unwrap() } + }; +} + +fn bench_nearest4(criterion: &mut Criterion) { + let mut group = criterion.benchmark_group("embedding/kernel/nearest4"); + + // k = 15 exercises the odd-k remainder path. + for &(d, k) in &[ + (256, nz!(15)), + (256, nz!(16)), + (256, nz!(64)), + (1536, nz!(16)), + (3072, nz!(16)), + ] { + let points: Vec> = (0..4).map(|seed| random_vec(d, 30 + seed)).collect(); + let centroids = random_vec(k.get() * d, 40); + + group.bench_with_input( + BenchmarkId::new(format!("d{d}"), k), + &(d, k), + |bencher, _| { + // SAFETY: point slices have length `d` (multiple of 8), + // centroids has length `k * d`, and `k > 0`. + bencher.iter(|| unsafe { + kernel::nearest4( + black_box(&points[0]), + black_box(&points[1]), + black_box(&points[2]), + black_box(&points[3]), + black_box(¢roids), + black_box(k), + black_box(d), + ) + }); + }, + ); + } + + group.finish(); +} + +fn bench_cluster(criterion: &mut Criterion) { + let mut group = criterion.benchmark_group("embedding/cluster"); + group.sample_size(10); + + let dimension = Dimension::new(256).expect("256 is a positive multiple of 8"); + + // (n, k): n = 10k exercises the subsampled fit (m = 8192) plus the + // full-data refinement; n = 50k shifts the weight onto the full-data + // passes. + for &(n, k) in &[ + (10_000_usize, 8_u16), + (10_000, 32), + (10_000, 128), + (50_000, 32), + ] { + let data = blobs(n / usize::from(k), usize::from(k), 256, 7); + let config = Config::for_k_with_seed(k, 42); + + group.bench_with_input( + BenchmarkId::new(format!("n{n}_d256"), k), + &(n, k), + |bencher, _| { + bencher.iter(|| cluster(black_box(&data), black_box(dimension), &config)); + }, + ); + } + + group.finish(); +} + +fn kernel_measurement() -> Criterion { + use core::time::Duration; + + // Retired instructions on Apple Silicon (needs root there), wall-clock + // fallback everywhere else. Instruction counts are near-deterministic, + // so short windows and small samples suffice. + Criterion::default() + .with_measurement( + darwin_kperf_criterion::HardwareCounter::instructions() + .expect("instruction counting requires root on Apple Silicon (run under sudo)"), + ) + .warm_up_time(Duration::from_millis(500)) + .measurement_time(Duration::from_secs(1)) + .sample_size(20) +} + +criterion_group!( + name = kernel; + config = kernel_measurement(); + targets = bench_dot, bench_add_scaled_into, bench_micro_4x2, bench_nearest4 +); +criterion_group!( + name = clustering; + config = Criterion::default(); + targets = bench_cluster +); +criterion_main!(kernel, clustering); diff --git a/libs/@local/graph/store/src/embedding/clustering.rs b/libs/@local/graph/store/src/embedding/clustering.rs new file mode 100644 index 00000000000..0ff02308ee7 --- /dev/null +++ b/libs/@local/graph/store/src/embedding/clustering.rs @@ -0,0 +1,1626 @@ +use alloc::borrow::Cow; +use core::{cmp, num::NonZero}; +use std::collections::HashSet; + +use rand::{Rng, RngExt as _, SeedableRng as _}; +use rand_xoshiro::Xoshiro256PlusPlus; +use rayon::{ + iter::{ + IndexedParallelIterator as _, IntoParallelIterator as _, IntoParallelRefIterator as _, + IntoParallelRefMutIterator as _, ParallelIterator as _, + }, + slice::{ParallelSlice as _, ParallelSliceMut as _}, +}; + +use super::{dimension::Dimension, kernel}; + +/// Parameters for k-means clustering. +/// +/// Use [`Config::for_k_with_seed`] to construct with reasonable defaults, then override individual +/// fields as needed. +pub struct Config { + /// Number of clusters. + pub k: u16, + + /// Maximum Lloyd iterations per run before declaring convergence. + pub max_iters: NonZero, + + /// Number of independent restarts. The run with the lowest inertia wins. + pub n_init: NonZero, + + /// Convergence tolerance: a run stops early when the relative change in + /// inertia between iterations falls below this value. + pub tol: f32, + + /// Maximum number of points sampled during k-means++ seeding. + /// Capped to avoid quadratic seeding cost on very large datasets. + pub sample_cap: usize, + + /// Base seed for the PRNG. + /// + /// Runs with the same seed, input, and configuration produce identical + /// labels and centroids. + pub seed: u64, + + /// Number of points processed per batch in the parallel passes. + /// Values larger than the number of points are clamped. + pub chunk: NonZero, +} + +impl Config { + /// Creates a configuration for `k` clusters with a fixed seed. + /// + /// Defaults: 30 max iterations, 5 restarts, 1e-4 convergence tolerance, + /// sample cap of min(256k, 8192), chunk size 256. + #[must_use] + pub fn for_k_with_seed(k: u16, seed: u64) -> Self { + Self { + k, + max_iters: const { NonZero::new(30).unwrap() }, + n_init: const { NonZero::new(5).unwrap() }, + tol: 1e-4, + sample_cap: cmp::min(256 * usize::from(k), 8192), + seed, + chunk: const { NonZero::new(256).unwrap() }, + } + } +} + +/// Result of spherical k-means clustering. +/// +/// `centroids` is a flat `k * d` row-major buffer where `d` is the +/// embedding [`Dimension`]. Centroid `i` occupies +/// `centroids[i * d .. (i + 1) * d]`. +pub struct Clustering { + pub dimension: Dimension, + + /// Flat centroid matrix, `k * d` elements in row-major order. + /// + /// Centroids are unit-normalized, with one exception: a cluster whose + /// members are all zero-norm points keeps a zero centroid, since there + /// is no direction to normalize. + pub centroids: Box<[f32]>, + + /// Cluster assignment for each input point, values in `0..k`. + /// + /// When [`cluster`] ran with `k == 0` (requested or clamped) the labels + /// are all-zero placeholders and there are no centroids to index. + pub labels: Box<[u16]>, + + /// Sum of squared chord distances from every input point to its assigned + /// centroid, measured against the final centroids. Lower is tighter; + /// comparable across runs on the same input, e.g. for choosing `k`. + /// `0.0` when `k == 0` or the input is empty. + /// + /// The value is precise only up to floating-point summation order: + /// repeated runs over identical input can differ in the final bits. + pub inertia: f32, +} + +impl Clustering { + /// Allocates a zeroed clustering for `k` centroids over `n` points. + fn new(k: u16, n: usize, d: Dimension) -> Self { + // SAFETY: All-zero bits are valid for `f32` (IEEE 754 positive zero) + // and for `u16` (the integer 0). `Box::new_zeroed_slice` allocates + // zeroed memory of the correct layout, so `assume_init` is sound. + let centroids: Box<[f32]> = + unsafe { Box::new_zeroed_slice((k as usize) * (d.get() as usize)).assume_init() }; + // SAFETY: All-zero bits are valid for `u16` (the integer 0). + let labels: Box<[u16]> = unsafe { Box::new_zeroed_slice(n).assume_init() }; + + Self { + centroids, + labels, + dimension: d, + inertia: 0.0, + } + } + + /// Returns the `D`-dimensional slice for centroid `cluster`. + /// + /// # Panics + /// + /// Panics if `cluster` is not below the number of centroids. + #[must_use] + pub fn centroid(&self, cluster: u16) -> &[f32] { + &self.centroids[cluster as usize * (self.dimension.get() as usize) + ..(cluster + 1) as usize * (self.dimension.get() as usize)] + } + + /// Returns a mutable `D`-dimensional slice for centroid `cluster`. + fn centroid_mut(&mut self, cluster: u16) -> &mut [f32] { + &mut self.centroids[cluster as usize * (self.dimension.get() as usize) + ..(cluster + 1) as usize * (self.dimension.get() as usize)] + } + + /// Returns the cluster label for point `entity`. + /// + /// # Panics + /// + /// Panics if `entity` is not below the number of input points. + #[must_use] + pub fn label(&self, entity: usize) -> u16 { + self.labels[entity] + } +} + +/// Draws `m` distinct indices uniformly at random from `0..n` in O(m) time +/// and memory (Robert Floyd's sampling algorithm). +/// +/// The result is sorted and deterministic for a given RNG state. +fn sample_indices(n: usize, m: usize, mut rng: impl Rng) -> Vec { + debug_assert!(m <= n); + + let mut selected: HashSet = HashSet::with_capacity(m); + + for upper in n - m..n { + let candidate = rng.random_range(0..=upper); + + if !selected.insert(candidate) { + // `candidate` was already drawn in an earlier round. Earlier + // rounds only drew from `0..upper`, so `upper` itself is fresh. + selected.insert(upper); + } + } + + let mut indices: Vec = selected.into_iter().collect(); + // Sorting erases the hash set's nondeterministic iteration order and + // turns the caller's gather into a forward walk over `x`. + indices.sort_unstable(); + indices +} + +/// Squared chord distance between a point and a unit centroid. +/// +/// For a unit centroid `c` and a point with inverse norm `inv`, the cosine +/// similarity is `dot(point, c) * inv`. The squared chord distance is +/// `2 - 2 * similarity`, which lies in `[0, 4]` and equals `||u - c||²` +/// when `u` is the unit-normalized point. +/// +/// Returns `0.0` for zero-norm points (`point_inv_norm == 0.0`). +/// +/// This is a squared distance. Do not square it again for D² sampling. +#[inline] +fn squared_chord_distance(dot: f32, point_inv_norm: f32) -> f32 { + if point_inv_norm == 0.0 { + return 0.0; + } + + let similarity = (dot * point_inv_norm).clamp(-1.0, 1.0); + + 2.0_f32.mul_add(-similarity, 2.0).max(0.0) +} + +/// Finds the nearest centroid to `point` and returns its index and spherical +/// distance. +/// +/// # Safety +/// +/// * `point.len() == d` +/// * `centroids.len() == k * d` +/// * `d` is a multiple of 8 (guaranteed by [`Dimension`]). +#[inline] +#[must_use] +pub(crate) unsafe fn nearest_centroid( + point: &[f32], + point_inv_norm: f32, + centroids: &[f32], + k: NonZero, + d: usize, +) -> (u16, f32) { + debug_assert_eq!(point.len(), d); + debug_assert_eq!(centroids.len(), k.get() * d); + + // SAFETY: the caller guarantees these preconditions. The hints let the + // compiler elide bounds checks on the centroid slicing inside the loop. + unsafe { + core::hint::assert_unchecked(point.len() == d); + core::hint::assert_unchecked(centroids.len() == k.get() * d); + core::hint::assert_unchecked(d.is_multiple_of(8)); + } + + let mut best = 0; + let mut best_dot = f32::NEG_INFINITY; + + for cluster in 0..k.get() { + let start = cluster * d; + let centroid = ¢roids[start..start + d]; + + // SAFETY: `point` and `centroid` both have length `D`, and `D` is a + // multiple of 8 (guaranteed by Dimension). + let dot = unsafe { kernel::dot(point, centroid) }; + + #[expect( + clippy::cast_possible_truncation, + reason = "cluster < k, and k originates from Config::k (u16)" + )] + if dot > best_dot { + best = cluster as u16; + best_dot = dot; + } + } + + (best, squared_chord_distance(best_dot, point_inv_norm)) +} + +/// Assigns one chunk of points during Lloyd iterations: writes each point's +/// nearest centroid into `labels` and its squared chord distance into +/// `distances`. +/// +/// # Safety +/// +/// * `points.len() == labels.len() * d` +/// * `inv_norms.len() == labels.len()` +/// * `distances.len() == labels.len()` +/// * `centroids.len() == k * d` +/// * `d` is a multiple of 8 +unsafe fn lloyd_assign( + k: NonZero, + d: usize, + centroids: &[f32], + points: &[f32], + inv_norms: &[f32], + labels: &mut [u16], + distances: &mut [f32], +) { + let count = labels.len(); + + // SAFETY: the caller guarantees the length relations; the hints let the + // compiler elide bounds checks in the tiled loop below. + unsafe { + core::hint::assert_unchecked(points.len() == count * d); + core::hint::assert_unchecked(inv_norms.len() == count); + core::hint::assert_unchecked(distances.len() == count); + core::hint::assert_unchecked(d.is_multiple_of(8)); + } + + let mut i = 0; + while i + 4 <= count { + let p0 = &points[i * d..i * d + d]; + let p1 = &points[(i + 1) * d..(i + 1) * d + d]; + let p2 = &points[(i + 2) * d..(i + 2) * d + d]; + let p3 = &points[(i + 3) * d..(i + 3) * d + d]; + + // SAFETY: each point length d, centroids length k*d, + // k > 0, d a multiple of 8 (guaranteed by Dimension). + let nearest = unsafe { kernel::nearest4(p0, p1, p2, p3, centroids, k, d) }; + + for offset in 0..4 { + labels[i + offset] = nearest[offset].0; + distances[i + offset] = + squared_chord_distance(nearest[offset].1, inv_norms[i + offset]); + } + i += 4; + } + + while i < count { + let point = &points[i * d..i * d + d]; + + // SAFETY: point length d, centroids length k*d, k > 0, d mult of 8. + let (label, distance) = unsafe { nearest_centroid(point, inv_norms[i], centroids, k, d) }; + labels[i] = label; + distances[i] = distance; + i += 1; + } +} + +/// Per-restart scratch and state for one k-means fit on the sample. +/// +/// Restarts run in parallel, so each owns its buffers. [`Restart::new`] +/// hands them back zeroed. +struct Restart { + k: NonZero, + m: usize, + d: usize, + + /// Centroids for this restart, `k * d` elements. + centroids: Box<[f32]>, + /// Per-cluster accumulator for centroid recomputation, `k * d` elements. + sums: Box<[f32]>, + /// Per-cluster point count for the empty-cluster check. + counts: Box<[usize]>, + /// Per-sample-point cluster assignment. + labels: Box<[u16]>, + /// Per-sample-point distance scratch. + point_distances: Box<[f32]>, + /// Tracks which sample points have been selected as seeds. + selected: Box<[bool]>, +} + +impl Restart { + fn new(k: NonZero, m: usize, d: usize) -> Self { + // SAFETY: all-zero bits are valid for `f32` (IEEE 754 +0.0), `usize` (0), `u16` (0), and + // `bool` (false). `Box::new_zeroed_slice` allocates zeroed memory of the correct + // layout for each type, so `assume_init` is sound in every case. + let centroids = unsafe { Box::<[f32]>::new_zeroed_slice(k.get() * d).assume_init() }; + // SAFETY: see above + let sums = unsafe { Box::<[f32]>::new_zeroed_slice(k.get() * d).assume_init() }; + // SAFETY: see above + let counts = unsafe { Box::<[usize]>::new_zeroed_slice(k.get()).assume_init() }; + // SAFETY: see above + let labels = unsafe { Box::<[u16]>::new_zeroed_slice(m).assume_init() }; + // SAFETY: see above + let point_distances = unsafe { Box::<[f32]>::new_zeroed_slice(m).assume_init() }; + // SAFETY: see above + let selected = unsafe { Box::<[bool]>::new_zeroed_slice(m).assume_init() }; + + Self { + k, + m, + d, + centroids, + sums, + counts, + labels, + point_distances, + selected, + } + } + + /// Runs one restart: k-means++ seeding followed by Lloyd iterations. + /// + /// Returns the sample inertia of the fitted centroids. + fn run( + &mut self, + sample: &[f32], + chunk: usize, + row_chunk: usize, + sample_inv_norms: &[f32], + seed: u64, + config: &Config, + ) -> f32 { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + + // `new` zeroes everything else; the distance scratch must start at + // infinity so the first seeding pass overwrites every entry. + self.point_distances.fill(f32::INFINITY); + + self.seed_plusplus(sample, sample_inv_norms, &mut rng); + self.lloyd(sample, chunk, row_chunk, sample_inv_norms, config) + } + + /// k-means++ D² weighted seeding. Picks `k` initial centroids from the + /// sample, each chosen with probability proportional to its squared + /// distance from the nearest already-chosen centroid. + fn seed_plusplus(&mut self, sample: &[f32], sample_inv_norms: &[f32], rng: &mut impl Rng) { + let &mut Self { d, k, m, .. } = self; + let mut point = rng.random_range(0..m); + + for cluster in 0..k.get() { + let centroid_start = cluster * d; + let point_start = point * d; + + self.centroids[centroid_start..centroid_start + d] + .copy_from_slice(&sample[point_start..point_start + d]); + + // SAFETY: centroid rows have length `d`, and `d` is a multiple of 8. + unsafe { + kernel::normalize(&mut self.centroids[centroid_start..centroid_start + d]); + } + + self.selected[point] = true; + + // The last centroid needs no D² update: those distances would + // only be used to sample a further seed. + if cluster + 1 == k.get() { + break; + } + + let centroid = &self.centroids[centroid_start..centroid_start + d]; + + // Per-element writes only, so the pass is deterministic under + // rayon; the D² total is summed sequentially below. + sample + .par_chunks_exact(d) + .zip(sample_inv_norms.par_iter()) + .zip(self.point_distances.par_iter_mut()) + .enumerate() + .for_each(|(index, ((point, &inv_norm), closest))| { + if self.selected[index] { + *closest = 0.0; + return; + } + + // SAFETY: `point` and `centroid` both have length `D`, and + // `D` is a multiple of 8 (guaranteed by Dimension). + let dot = unsafe { kernel::dot(point, centroid) }; + let distance = squared_chord_distance(dot, inv_norm); + + if distance < *closest { + *closest = distance; + } + }); + + let total: f32 = self.point_distances.iter().sum(); + + point = if total.is_finite() && total > 0.0 { + let mut target = rng.random_range(0.0..total); + let mut sampled = None; + let mut last_positive = 0; + + for (index, &distance) in self.point_distances.iter().enumerate() { + if distance <= 0.0 { + continue; + } + + last_positive = index; + target -= distance; + + if target <= 0.0 { + sampled = Some(index); + break; + } + } + + // Rounding can leave `target` marginally positive after the last bucket; + // fall back to the last point with positive mass. + sampled.unwrap_or(last_positive) + } else { + // Degenerate geometry: every remaining point coincides with a seed. + // Pick uniformly among the unselected points. + let remaining = self.selected.iter().filter(|selected| !**selected).count(); + let mut target = rng.random_range(0..remaining); + let mut sampled = 0; + + for (index, selected) in self.selected.iter().copied().enumerate() { + if selected { + continue; + } + + if target == 0 { + sampled = index; + break; + } + + target -= 1; + } + + sampled + }; + } + } + + /// Runs Lloyd iterations on the sample until convergence or `max_iters`. + /// + /// Returns the final inertia (sum of distances to assigned centroids). + fn lloyd( + &mut self, + sample: &[f32], + chunk: usize, + row_chunk: usize, + sample_inv_norms: &[f32], + config: &Config, + ) -> f32 { + let &mut Self { d, k, .. } = self; + let mut previous_inertia = f32::INFINITY; + let mut inertia = f32::INFINITY; + + for _ in 0..config.max_iters.get() { + // Assignment: labels and per-point distances. + // Trivially deterministic under rayon, as writes are per element. + sample + .par_chunks(row_chunk) + .zip(sample_inv_norms.par_chunks(chunk)) + .zip(self.labels.par_chunks_mut(chunk)) + .zip(self.point_distances.par_chunks_mut(chunk)) + .for_each(|(((points, inv_norms), labels), distances)| { + // SAFETY: `par_chunks(row_chunk)` with `row_chunk == chunk * d` + // pairs `labels.len()` labels, distances, and inv norms with + // `labels.len() * d` floats of points. `self.centroids` has + // length `k * d`, and `d` is a multiple of 8 (guaranteed by + // Dimension). + unsafe { + lloyd_assign(k, d, &self.centroids, points, inv_norms, labels, distances); + }; + }); + + inertia = self.point_distances.iter().sum(); + + // SAFETY: `sample.len() == m * d` with `m` labels, sums is + // `k * d` with `k` counts, and `d` is a multiple of 8 + // (guaranteed by Dimension). + unsafe { + accumulate_clusters( + sample, + &self.labels, + Some(sample_inv_norms), + &mut self.sums, + &mut self.counts, + d, + ); + } + + for cluster in 0..k.get() { + if self.counts[cluster] == 0 { + continue; + } + + let start = cluster * d; + + // Normalization is scale-invariant, so the raw sum gives the same direction as the + // average. + self.centroids[start..start + d].copy_from_slice(&self.sums[start..start + d]); + + // SAFETY: centroid rows have length `d`, and `d` is a multiple of 8 + // (guaranteed by Dimension). + unsafe { + kernel::normalize(&mut self.centroids[start..start + d]); + } + } + + let reseeded = self.reinit_empty_clusters(sample); + + // Skip the convergence check when a cluster was just reseeded: + // the reseeded centroid hasn't had an assignment pass yet, so + // breaking now would waste the reinit. + if !reseeded && previous_inertia.is_finite() { + let relative_change = + (previous_inertia - inertia).abs() / previous_inertia.max(f32::EPSILON); + + if relative_change <= config.tol { + break; + } + } + + previous_inertia = inertia; + } + + inertia + } + + /// Reinitializes empty clusters from the sample point farthest from its + /// assigned centroid, using the distances stored by the assignment pass. + /// + /// After relocating a point its stored distance is zeroed and its label + /// updated, so subsequent empty clusters in the same pass pick different + /// points. + #[expect( + clippy::cast_possible_truncation, + reason = "cluster index < k, and k originates from Config::k (u16)" + )] + fn reinit_empty_clusters(&mut self, sample: &[f32]) -> bool { + let &mut Self { d, k, .. } = self; + let mut reseeded = false; + + for cluster in 0..k.get() { + if self.counts[cluster] != 0 { + continue; + } + + reseeded = true; + + let mut farthest_idx = 0; + let mut farthest_dist = -1.0_f32; + + for (index, &distance) in self.point_distances.iter().enumerate() { + if distance > farthest_dist { + farthest_dist = distance; + farthest_idx = index; + } + } + + let point_start = farthest_idx * d; + let centroid_start = cluster * d; + + self.centroids[centroid_start..centroid_start + d] + .copy_from_slice(&sample[point_start..point_start + d]); + + // SAFETY: centroid rows have length `d`, a multiple of 8. + unsafe { + kernel::normalize(&mut self.centroids[centroid_start..centroid_start + d]); + } + + self.labels[farthest_idx] = cluster as u16; + self.point_distances[farthest_idx] = 0.0; + } + + reseeded + } +} + +/// Recomputes per-cluster sums and counts from labeled points. +/// +/// The result is impervious to any thread schedule order. +/// +/// `inv_norms` supplies precomputed inverse norms; pass `None` to compute +/// them on the fly. +/// +/// Zero-norm points are counted but contribute nothing to the sums. +/// +/// # Safety +/// +/// * `points.len() == labels.len() * d` +/// * `sums.len() == counts.len() * d` +/// * `inv_norms`, when provided, has one entry per label +/// * `d` is a multiple of 8 +unsafe fn accumulate_clusters( + points: &[f32], + labels: &[u16], + inv_norms: Option<&[f32]>, + sums: &mut [f32], + counts: &mut [usize], + d: usize, +) { + // A `debug_assert` only: an `assert_unchecked` here would be sound (the + // length is a documented precondition), but to elide the + // `inv_norms[index]` bounds check the fact would have to survive into + // the rayon closure, and that check is noise next to the `d`-wide + // kernel call it precedes. + debug_assert!(inv_norms.is_none_or(|norms| norms.len() == labels.len())); + + sums.par_chunks_exact_mut(d) + .zip(counts.par_iter_mut()) + .enumerate() + .for_each(|(cluster, (sum, count))| { + sum.fill(0.0); + *count = 0; + + for (index, (point, &label)) in points.chunks_exact(d).zip(labels).enumerate() { + if usize::from(label) != cluster { + continue; + } + + *count += 1; + + let inv_norm = inv_norms.map_or_else( + || { + // SAFETY: `point` has length `d`, a multiple of 8 + // (guaranteed by the caller). + let norm = unsafe { kernel::dot(point, point) }.sqrt(); + + if norm > 0.0 { norm.recip() } else { 0.0 } + }, + |inv_norms| inv_norms[index], + ); + + if inv_norm == 0.0 { + continue; + } + + // SAFETY: `sum` and `point` both have length `d`, and `d` is + // a multiple of 8 (guaranteed by the caller). + unsafe { + kernel::add_scaled_into(sum, point, inv_norm); + } + } + }); +} + +/// Labels one parallel chunk: each point gets its nearest centroid. +/// +/// # Safety +/// +/// * `points.len() == labels.len() * d` +/// * `centroids.len() >= k * d` +/// * `d` is a multiple of 8 +unsafe fn label_chunk( + centroids: &[f32], + k: NonZero, + d: usize, + points: &[f32], + labels: &mut [u16], +) { + let count = labels.len(); + + // SAFETY: each parallel chunk pairs `count` labels with `count * d` + // floats of point data; `d` is a multiple of 8 (guaranteed by Dimension). + unsafe { + core::hint::assert_unchecked(points.len() == count * d); + core::hint::assert_unchecked(centroids.len() >= k.get() * d); + core::hint::assert_unchecked(d.is_multiple_of(8)); + } + + let mut i = 0; + while i + 4 <= count { + let p0 = &points[i * d..i * d + d]; + let p1 = &points[(i + 1) * d..(i + 1) * d + d]; + let p2 = &points[(i + 2) * d..(i + 2) * d + d]; + let p3 = &points[(i + 3) * d..(i + 3) * d + d]; + + // SAFETY: each point length d, centroids length k*d, k > 0, + // d a multiple of 8 (guaranteed by Dimension). + let nearest = unsafe { kernel::nearest4(p0, p1, p2, p3, centroids, k, d) }; + + labels[i] = nearest[0].0; + labels[i + 1] = nearest[1].0; + labels[i + 2] = nearest[2].0; + labels[i + 3] = nearest[3].0; + i += 4; + } + + while i < count { + let point = &points[i * d..i * d + d]; + // SAFETY: point length d, centroids length k*d, k > 0, d mult of 8. + let (label, _) = unsafe { nearest_centroid(point, 1.0, centroids, k, d) }; + labels[i] = label; + i += 1; + } +} + +/// Labels one parallel chunk against the final centroids and returns its +/// inertia contribution. Inverse norms are computed on the fly. +/// +/// # Safety +/// +/// * `points.len() == labels.len() * d` +/// * `centroids.len() >= k * d` +/// * `d` is a multiple of 8 +unsafe fn score_chunk( + centroids: &[f32], + k: NonZero, + d: usize, + points: &[f32], + labels: &mut [u16], +) -> f32 { + debug_assert_eq!(points.len(), labels.len() * d); + + // SAFETY: The caller must ensure `points.len() == labels.len() * d`. + unsafe { + core::hint::assert_unchecked(points.len() == labels.len() * d); + } + + let count = labels.len(); + let mut inertia = 0.0_f32; + + let mut i = 0; + while i + 4 <= count { + let p0 = &points[i * d..i * d + d]; + let p1 = &points[(i + 1) * d..(i + 1) * d + d]; + let p2 = &points[(i + 2) * d..(i + 2) * d + d]; + let p3 = &points[(i + 3) * d..(i + 3) * d + d]; + + // SAFETY: each point length d, centroids length k*d, k > 0, + // d a multiple of 8 (guaranteed by Dimension). + let nearest = unsafe { kernel::nearest4(p0, p1, p2, p3, centroids, k, d) }; + let ps = [p0, p1, p2, p3]; + + for offset in 0..4 { + labels[i + offset] = nearest[offset].0; + + // SAFETY: point length d, a multiple of 8. + let norm = unsafe { kernel::dot(ps[offset], ps[offset]) }.sqrt(); + let inv_norm = if norm > 0.0 { norm.recip() } else { 0.0 }; + inertia += squared_chord_distance(nearest[offset].1, inv_norm); + } + i += 4; + } + + while i < count { + let point = &points[i * d..i * d + d]; + + // SAFETY: point length d, a multiple of 8. + let norm = unsafe { kernel::dot(point, point) }.sqrt(); + let inv_norm = if norm > 0.0 { norm.recip() } else { 0.0 }; + + // SAFETY: point length d, centroids length k*d, k > 0, d mult of 8. + let (label, distance) = unsafe { nearest_centroid(point, inv_norm, centroids, k, d) }; + labels[i] = label; + inertia += distance; + i += 1; + } + + inertia +} + +/// Labels every point with its nearest centroid. +/// +/// # Safety +/// +/// * `x.len() == n * d` for some `n` +/// * `clustering.centroids.len() == k * d` +/// * `clustering.labels.len() == n` +/// * `d` is a multiple of 8 +unsafe fn reassign( + x: &[f32], + centroids: &[f32], + labels: &mut [u16], + k: NonZero, + d: usize, + chunk: usize, + row_chunk: usize, +) { + x.par_chunks(row_chunk) + .zip(labels.par_chunks_mut(chunk)) + .for_each(|(points, labels)| { + // SAFETY: `par_chunks(row_chunk)` with `row_chunk = chunk * d` + // ensures `points.len() == labels.len() * d`. Centroids and k + // are valid from the caller. + unsafe { + label_chunk(centroids, k, d, points, labels); + } + }); +} + +/// Labels every point with its nearest centroid and returns the total +/// inertia. +/// +/// Labels are exact; the inertia is precise only up to floating-point +/// summation order. +/// +/// # Safety +/// +/// * `x.len() == n * d` for some `n` +/// * `clustering.centroids.len() == k * d` +/// * `clustering.labels.len() == n` +/// * `d` is a multiple of 8 +unsafe fn reassign_scored( + x: &[f32], + centroids: &[f32], + labels: &mut [u16], + k: NonZero, + d: usize, + chunk: usize, + row_chunk: usize, +) -> f32 { + // Unordered parallel reduction: the grouping follows rayon's scheduling, so the sum is not + // bit-stable. An ordered reduction would need to collect per-chunk partials, costing an + // allocation per call. + x.par_chunks(row_chunk) + .zip(labels.par_chunks_mut(chunk)) + .map(|(points, labels)| { + // SAFETY: `par_chunks(row_chunk)` with `row_chunk = chunk * d` + // ensures `points.len() == labels.len() * d`. Centroids and k + // are valid from the caller. + unsafe { score_chunk(centroids, k, d, points, labels) } + }) + .sum() +} + +/// Assigns all `n` points to their nearest centroid, recomputes centroids +/// from the full population, and re-labels against the final centroids. +/// Returns the full-data inertia. +/// +/// `sums` and `counts` are accumulator scratch; their contents on entry are +/// irrelevant. +/// +/// # Safety +/// +/// * `x.len() == n * d` for some `n` +/// * `clustering.centroids.len() == k * d` +/// * `clustering.labels.len() == n` +/// * `sums.len() == k * d` and `counts.len() == k` +/// * `d` is a multiple of 8 +unsafe fn assign( + x: &[f32], + clustering: &mut Clustering, + k: NonZero, + chunk: usize, + row_chunk: usize, + sums: &mut [f32], + counts: &mut [usize], +) -> f32 { + let d = clustering.dimension.get() as usize; + + // 1. Label all points against the sample-fitted centroids. + // SAFETY: forwarded from the caller. + unsafe { + reassign( + x, + &clustering.centroids, + &mut clustering.labels, + k, + d, + chunk, + row_chunk, + ); + } + + // 2. Recompute centroids from the full population. + // SAFETY: `x.len() == n * d` with `n` labels, sums is `k * d` with `k` + // counts, and `d` is a multiple of 8 (guaranteed by Dimension). + unsafe { + accumulate_clusters(x, &clustering.labels, None, sums, counts, d); + } + + for (cluster, count) in counts.iter_mut().enumerate() { + if *count == 0 { + continue; + } + + let start = cluster * d; + + #[expect( + clippy::cast_possible_truncation, + reason = "cluster < k and k originates from Config::k (u16)" + )] + let centroid = clustering.centroid_mut(cluster as u16); + // Normalization is scale-invariant, so the raw sum gives the same + // direction as the average. + centroid.copy_from_slice(&sums[start..start + d]); + + // SAFETY: centroid length d, a multiple of 8. + unsafe { + kernel::normalize(centroid); + } + } + + // 3. Final labels and inertia against the recomputed centroids. + // SAFETY: centroids were just recomputed in place; same invariants hold. + unsafe { + reassign_scored( + x, + &clustering.centroids, + &mut clustering.labels, + k, + d, + chunk, + row_chunk, + ) + } +} + +/// Runs spherical k-means over a flat row-major embedding matrix. +/// +/// `x` contains `n` points of `dimension` floats each, laid out +/// contiguously. Returns cluster assignments, unit-normalized centroids, and +/// the full-data inertia. +/// +/// Given the same input and configuration, the returned labels and +/// centroids are identical across runs; the inertia is precise only up to +/// floating-point summation order (see [`Clustering::inertia`]). +/// +/// Zero-norm points do not influence centroids, and are always assigned to +/// cluster 0 at distance 0. If a cluster consists solely of zero-norm points, +/// its centroid is zero; see [`Clustering::centroids`]. +/// +/// If `config.k == 0` or `x` is empty there is nothing to fit: the result +/// has no centroids, all-zero placeholder labels, and an inertia of `0.0`. +/// +/// # Panics +/// +/// Panics if `x.len()` is not a multiple of `dimension`. +#[must_use] +#[expect(clippy::integer_division_remainder_used, clippy::integer_division)] +pub fn cluster(x: &[f32], dimension: Dimension, config: &Config) -> Clustering { + let d = dimension.get() as usize; + assert!(x.len().is_multiple_of(d)); + + let n = x.len() / d; + let k = cmp::min(config.k, n.saturating_truncate()); + + let mut clustering = Clustering::new(k, n, dimension); + + let Some(k) = NonZero::new(k) else { + return clustering; + }; + + let k = NonZero::from(k); + let mut rng = Xoshiro256PlusPlus::seed_from_u64(config.seed); + + // 1. subsample (fit on all of n only when n is already small) + let m = config.sample_cap.max(k.get()).min(n); + + let sample = if m == n { + Cow::Borrowed(x) + } else { + let indices = sample_indices(n, m, &mut rng); + let mut sampled = vec![0.0_f32; m * d]; + + let chunks = sampled.chunks_mut(d); + assert_eq!(chunks.len(), indices.len()); + + for (chunk, index) in chunks.zip(indices) { + chunk.copy_from_slice(&x[index * d..(index + 1) * d]); + } + + Cow::Owned(sampled) + }; + + let sample = sample.as_ref(); + + // Clamping to `n` keeps `row_chunk` from overflowing: `chunk * d` is at most `n * d == + // x.len()`. + let chunk = cmp::min(config.chunk.get(), n); + let row_chunk = chunk * d; + + let sample_inv_norms: Vec = sample + .par_chunks_exact(d) + .map(|point| { + // SAFETY: every point is a `d`-sized row, and `d` is a multiple of 8 (guaranteed by + // Dimension). + let norm = unsafe { kernel::dot(point, point) }.sqrt(); + + if norm > 0.0 { norm.recip() } else { 0.0 } + }) + .collect(); + + // 2. fit on the sample: independent k-means++ restarts in parallel, the + // run with the lowest inertia wins (guards against bad initializations). + // Seeds are pre-derived so the stream matches a sequential run; ties + // break on the restart index, which keeps the winner deterministic no + // matter how rayon schedules the restarts. + let seeds: Vec = core::iter::repeat_with(|| rng.random()) + .take(usize::try_from(config.n_init.get()).unwrap_or(usize::MAX)) + .collect(); + + let best = seeds + .into_par_iter() + .enumerate() + .map(|(index, seed)| { + let mut restart = Restart::new(k, m, d); + let inertia = restart.run(sample, chunk, row_chunk, &sample_inv_norms, seed, config); + + (inertia, index, restart) + }) + .min_by(|lhs, rhs| lhs.0.total_cmp(&rhs.0).then(lhs.1.cmp(&rhs.1))) + .expect("config.n_init is non-zero, so at least one restart ran"); + + // Reuse the winning restart's buffers: its centroids become the result + // and its per-cluster accumulators serve the full-data recomputation, + // instead of allocating fresh ones. + let Restart { + centroids, + mut sums, + mut counts, + .. + } = best.2; + + clustering.centroids = centroids; + + // 3. assign points to clusters + // SAFETY: `x.len() == n * d` (asserted above), `clustering.centroids.len() == k * d`, + // `sums` and `counts` are the restart's `k * d` and `k` sized accumulators, + // and `d` is a multiple of 8 (guaranteed by Dimension). + clustering.inertia = unsafe { + assign( + x, + &mut clustering, + k, + chunk, + row_chunk, + &mut sums, + &mut counts, + ) + }; + + clustering +} + +#[cfg(test)] +mod tests { + #![expect( + clippy::float_cmp, + clippy::integer_division_remainder_used, + reason = "test module: float comparisons are intentional for exact-zero and distance \ + checks; modulo is used in test data construction" + )] + use super::*; + + macro_rules! nz { + ($expr:expr) => { + const { NonZero::new($expr).unwrap() } + }; + } + + /// Builds well-separated blob clusters in D-dimensional space. + /// + /// Each blob has a dominant axis so clusters are far apart in cosine + /// space. Returns `(flat_points, ground_truth_labels)`. + #[expect( + clippy::cast_possible_truncation, + reason = "k is small in tests, fits in u16" + )] + fn make_blobs( + points_per_cluster: usize, + k: usize, + seed: u64, + ) -> (Vec, Vec) { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + let n = points_per_cluster * k; + let mut data = vec![0.0_f32; n * D]; + let mut truth = vec![0_u16; n]; + + for c in 0..k { + let axis = c % D; + for p in 0..points_per_cluster { + let idx = c * points_per_cluster + p; + let row = &mut data[idx * D..(idx + 1) * D]; + + row[axis] = 10.0; + for val in row.iter_mut() { + *val += rng.random_range(-0.01..0.01); + } + + truth[idx] = c as u16; + } + } + + (data, truth) + } + + const D: usize = 64; + + fn l2(v: &[f32]) -> f32 { + v.iter().map(|x| x * x).sum::().sqrt() + } + + /// Random unit-norm centroids in `D`-dimensional space. + fn unit_random(k: NonZero, seed: u64) -> Vec { + let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); + let mut c = vec![0.0_f32; k.get() * D]; + for row in c.chunks_exact_mut(D) { + for v in row.iter_mut() { + *v = rng.random_range(-1.0..1.0); + } + let n = l2(row); + for v in row.iter_mut() { + *v /= n; + } + } + c + } + + /// Brute-force nearest centroid by cosine similarity. + #[expect(clippy::cast_possible_truncation, reason = "k is small in tests")] + fn brute_nearest_cosine(point: &[f32], centroids: &[f32], k: NonZero) -> u16 { + let pn = l2(point); + let mut best = 0_u16; + let mut best_cos = f32::NEG_INFINITY; + + for c in 0..k.get() { + let cent = ¢roids[c * D..(c + 1) * D]; + let d: f32 = point.iter().zip(cent).map(|(a, b)| a * b).sum(); + let cn = l2(cent); + let cos = if pn == 0.0 || cn == 0.0 { + 0.0 + } else { + d / (pn * cn) + }; + + if cos > best_cos { + best_cos = cos; + best = c as u16; + } + } + best + } + + /// Computes clustering accuracy using majority-vote label mapping. + /// + /// K-means labels are permutation-invariant, so this assigns each + /// predicted cluster to whichever ground-truth cluster it overlaps + /// most, then counts correctly assigned points. + #[expect( + clippy::cast_precision_loss, + reason = "counts are small test values, well within f64 precision" + )] + fn accuracy(predicted: &[u16], truth: &[u16], k: usize) -> f64 { + let mut votes = vec![vec![0_usize; k]; k]; + for (&pred, &true_label) in predicted.iter().zip(truth) { + votes[pred as usize][true_label as usize] += 1; + } + + let correct: usize = votes + .iter() + .map(|row| row.iter().copied().max().unwrap_or(0)) + .sum(); + + correct as f64 / predicted.len() as f64 + } + + /// Shorthand for [`Dimension::new`] that panics on invalid input. + fn dim(d: u16) -> Dimension { + Dimension::new(d).expect("test dimension must be a positive multiple of 8") + } + + #[test] + fn chord_identical_vectors_is_zero() { + // dot=1.0, inv_norm=1.0 => similarity=1 => distance=0 + assert_eq!(squared_chord_distance(1.0, 1.0), 0.0); + } + + #[test] + fn chord_orthogonal_vectors() { + // dot=0 => similarity=0 => distance=2 + let dist = squared_chord_distance(0.0, 1.0); + assert!((dist - 2.0).abs() < 1e-6, "expected 2.0, got {dist}"); + } + + #[test] + fn chord_opposite_vectors() { + // dot=-1.0 => similarity=-1 => distance=4 + let dist = squared_chord_distance(-1.0, 1.0); + assert!((dist - 4.0).abs() < 1e-6, "expected 4.0, got {dist}"); + } + + #[test] + fn chord_zero_norm_returns_zero() { + assert_eq!(squared_chord_distance(0.5, 0.0), 0.0); + assert_eq!(squared_chord_distance(-0.5, 0.0), 0.0); + } + + #[test] + fn chord_is_non_negative() { + for dot_val in [0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0] { + for inv in [0.0, 0.5, 1.0, 2.0] { + let dist = squared_chord_distance(dot_val, inv); + assert!(dist >= 0.0, "negative for dot={dot_val}, inv={inv}: {dist}"); + } + } + } + + #[test] + fn sample_indices_unique_sorted_in_range() { + let rng = Xoshiro256PlusPlus::seed_from_u64(1); + let indices = sample_indices(1000, 100, rng); + + assert_eq!(indices.len(), 100); + assert!( + indices.is_sorted_by(|lhs, rhs| lhs < rhs), + "indices must be strictly increasing (sorted, unique)" + ); + assert!(indices.iter().all(|&index| index < 1000)); + } + + #[test] + fn cluster_empty_input() { + let config = Config::for_k_with_seed(4, 42); + let result = cluster(&[], dim(8), &config); + assert_eq!(result.labels.len(), 0); + assert_eq!(result.centroids.len(), 0); + assert_eq!(result.inertia, 0.0); + } + + #[test] + fn cluster_k0() { + let data = vec![1.0_f32; 8]; + let config = Config::for_k_with_seed(0, 42); + let result = cluster(&data, dim(8), &config); + assert_eq!(result.labels.len(), 1); + assert_eq!(result.labels[0], 0); + assert_eq!(result.centroids.len(), 0); + assert_eq!(result.inertia, 0.0); + } + + #[test] + fn cluster_k1_all_same_label() { + let (data, _) = make_blobs::<8>(20, 3, 123); + let config = Config::for_k_with_seed(1, 42); + let result = cluster(&data, dim(8), &config); + + assert_eq!(result.labels.len(), 60); + assert!( + result.labels.iter().all(|&l| l == 0), + "k=1: all labels must be 0" + ); + } + + #[test] + fn cluster_single_point() { + let data = vec![1.0_f32; 16]; + let config = Config::for_k_with_seed(5, 42); + // k clamped to min(k, n) = 1 + let result = cluster(&data, dim(16), &config); + assert_eq!(result.labels.len(), 1); + assert_eq!(result.labels[0], 0); + } + + #[test] + fn cluster_n_less_than_4() { + // n=3 exercises the scalar tail (no nearest4 tiling). + let (data, _) = make_blobs::<8>(1, 3, 99); + let config = Config::for_k_with_seed(3, 42); + let result = cluster(&data, dim(8), &config); + + assert_eq!(result.labels.len(), 3); + let mut seen = [false; 3]; + for &label in &*result.labels { + seen[label as usize] = true; + } + assert!( + seen.iter().all(|&s| s), + "each point should have a unique cluster" + ); + } + + #[test] + fn cluster_n_equals_k() { + let (data, _) = make_blobs::<8>(1, 5, 77); + let config = Config::for_k_with_seed(5, 42); + let result = cluster(&data, dim(8), &config); + + assert_eq!(result.labels.len(), 5); + let mut seen = [false; 5]; + for &label in &*result.labels { + seen[label as usize] = true; + } + assert!( + seen.iter().all(|&s| s), + "n=k: each point should be its own cluster" + ); + } + + #[test] + fn cluster_recovers_well_separated_blobs() { + let (data, truth) = make_blobs::<8>(50, 4, 314); + let config = Config::for_k_with_seed(4, 42); + let result = cluster(&data, dim(8), &config); + + let acc = accuracy(&result.labels, &truth, 4); + assert!( + acc > 0.95, + "expected >95% accuracy on well-separated blobs, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_deterministic_with_same_seed() { + let (data, _) = make_blobs::<8>(30, 3, 555); + + let r1 = cluster(&data, dim(8), &Config::for_k_with_seed(3, 42)); + let r2 = cluster(&data, dim(8), &Config::for_k_with_seed(3, 42)); + + assert_eq!(r1.labels, r2.labels); + assert_eq!(r1.centroids, r2.centroids); + + // The inertia reduction is a parallel sum, so it is only + // deterministic up to float summation order. + let tolerance = r1.inertia.abs().max(f32::EPSILON) * 1e-5; + assert!( + (r1.inertia - r2.inertia).abs() <= tolerance, + "inertia should agree within summation-order tolerance: {} vs {}", + r1.inertia, + r2.inertia + ); + } + + #[test] + fn cluster_recovers_blobs_across_seeds() { + let (data, truth) = make_blobs::<8>(30, 3, 555); + + for seed in [42, 9999] { + let result = cluster(&data, dim(8), &Config::for_k_with_seed(3, seed)); + let acc = accuracy(&result.labels, &truth, 3); + assert!( + acc > 0.95, + "seed {seed}: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + } + + #[test] + fn cluster_centroids_are_unit_normalized() { + let (data, _) = make_blobs::<8>(40, 4, 222); + let config = Config::for_k_with_seed(4, 42); + let result = cluster(&data, dim(8), &config); + + for c in 0..4_u16 { + let centroid = result.centroid(c); + // SAFETY: centroid has length 8 (= D), a multiple of 8. + let norm = unsafe { kernel::dot(centroid, centroid).sqrt() }; + assert!( + (norm - 1.0).abs() < 1e-5, + "centroid {c} has norm {norm}, expected 1.0" + ); + } + } + + #[test] + fn cluster_labels_in_range() { + let (data, _) = make_blobs::<8>(25, 5, 333); + let config = Config::for_k_with_seed(5, 42); + let result = cluster(&data, dim(8), &config); + + for (i, &label) in result.labels.iter().enumerate() { + assert!(label < 5, "label[{i}] = {label}, expected < 5"); + } + } + + #[test] + fn cluster_labels_nearest_to_assigned_centroid() { + let (data, _) = make_blobs::<8>(30, 3, 444); + let config = Config::for_k_with_seed(3, 42); + let result = cluster(&data, dim(8), &config); + + let k = 3_usize; + let d = 8_usize; + for (i, point) in data.chunks_exact(d).enumerate() { + let assigned = result.labels[i]; + // SAFETY: point and centroid both have length 8 (= D), a multiple of 8. + let assigned_dot = unsafe { kernel::dot(point, result.centroid(assigned)) }; + + #[expect(clippy::cast_possible_truncation, reason = "k=3 fits in u16")] + for c in 0..k as u16 { + // SAFETY: point and centroid both have length 8, a multiple of 8. + let other_dot = unsafe { kernel::dot(point, result.centroid(c)) }; + assert!( + other_dot <= assigned_dot + 1e-5, + "point {i}: assigned to {assigned} (dot={assigned_dot}) but centroid {c} has \ + higher dot={other_dot}" + ); + } + } + } + + #[test] + fn cluster_d32_recovers_blobs() { + let (data, truth) = make_blobs::<32>(40, 3, 888); + let config = Config::for_k_with_seed(3, 42); + let result = cluster(&data, dim(32), &config); + + let acc = accuracy(&result.labels, &truth, 3); + assert!( + acc > 0.95, + "D=32: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_d256_recovers_blobs() { + // Production default dimension (matryoshka truncation target). + let (data, truth) = make_blobs::<256>(50, 4, 1234); + let config = Config::for_k_with_seed(4, 42); + let result = cluster(&data, dim(256), &config); + + let acc = accuracy(&result.labels, &truth, 4); + assert!( + acc > 0.95, + "D=256: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_d1536_recovers_blobs() { + let (data, truth) = make_blobs::<1536>(20, 3, 4321); + let config = Config::for_k_with_seed(3, 42); + let result = cluster(&data, dim(1536), &config); + + let acc = accuracy(&result.labels, &truth, 3); + assert!( + acc > 0.95, + "D=1536: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_chunk_sizes_produce_valid_results() { + let (data, truth) = make_blobs::<8>(50, 4, 314); + + for chunk in [1_usize, 3, 1_000_000] { + let mut config = Config::for_k_with_seed(4, 42); + config.chunk = NonZero::new(chunk).expect("chunk is non-zero"); + let result = cluster(&data, dim(8), &config); + + let acc = accuracy(&result.labels, &truth, 4); + assert!( + acc > 0.95, + "chunk={chunk}: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + } + + #[test] + fn cluster_inertia_reflects_fit_quality() { + let (data, _) = make_blobs::<8>(50, 4, 99); + + let tight = cluster(&data, dim(8), &Config::for_k_with_seed(4, 42)); + assert!(tight.inertia.is_finite()); + assert!(tight.inertia >= 0.0); + + // Forcing 4 well-separated blobs into a single cluster must fit + // strictly worse. + let loose = cluster(&data, dim(8), &Config::for_k_with_seed(1, 42)); + assert!( + loose.inertia > tight.inertia, + "k=1 inertia {} should exceed k=4 inertia {}", + loose.inertia, + tight.inertia + ); + } + + #[test] + fn cluster_recovers_with_subsampling() { + // n=12000 with sample_cap=1024 exercises the Cow::Owned path. + let (data, truth) = make_blobs::<8>(2000, 6, 21); + let mut config = Config::for_k_with_seed(6, 5); + config.sample_cap = 1024; + let result = cluster(&data, dim(8), &config); + + let acc = accuracy(&result.labels, &truth, 6); + assert!( + acc > 0.95, + "subsampled: expected >95% accuracy, got {:.1}%", + acc * 100.0 + ); + } + + #[test] + fn cluster_more_clusters_than_natural_groups() { + // 3 natural groups but k=8: empty clusters keep their seed centroid, + // nothing should be NaN or infinite. + let (data, _) = make_blobs::<8>(400, 3, 31); + let result = cluster(&data, dim(8), &Config::for_k_with_seed(8, 1)); + + assert!( + result.centroids.iter().all(|v| v.is_finite()), + "NaN or infinite centroid" + ); + assert!(result.labels.iter().all(|&l| l < 8)); + } + + #[test] + fn cluster_all_identical_points() { + // Every point identical: D² distances are all zero during seeding, + // which triggers the uniform fallback path. + let n = 100; + let mut data = vec![0.0_f32; n * 8]; + for row in data.chunks_exact_mut(8) { + row[0] = 1.0; + } + let result = cluster(&data, dim(8), &Config::for_k_with_seed(4, 1)); + + assert!(result.centroids.iter().all(|v| v.is_finite())); + assert!(result.labels.iter().all(|&l| l < 4)); + } + + #[test] + fn nearest_centroid_matches_brute_force_cosine() { + let k = nz!(7); + let centroids = unit_random(k, 99); + let mut rng = Xoshiro256PlusPlus::seed_from_u64(100); + + for _ in 0..1000 { + let p: Vec = core::iter::repeat_with(|| rng.random_range(-3.0..3.0)) + .take(D) + .collect(); + let pn = l2(&p); + let inv = if pn > 0.0 { pn.recip() } else { 0.0 }; + + // SAFETY: point has length D=64, centroids has length k*D, + // k > 0, D is a multiple of 8. + let (got, _) = unsafe { nearest_centroid(&p, inv, ¢roids, k, D) }; + assert_eq!( + got, + brute_nearest_cosine(&p, ¢roids, k), + "mismatch for point norm={pn}" + ); + } + } + + #[test] + fn nearest_centroid_argmax_independent_of_inv_norm() { + let k = nz!(5); + let centroids = unit_random(k, 7); + let mut rng = Xoshiro256PlusPlus::seed_from_u64(8); + + for _ in 0..500 { + let p: Vec = core::iter::repeat_with(|| rng.random_range(-2.0..2.0)) + .take(D) + .collect(); + + // SAFETY: point has length D=64, centroids has length k*D, + // k > 0, D is a multiple of 8. + let (a, _) = unsafe { nearest_centroid(&p, 1.0, ¢roids, k, D) }; + // SAFETY: same preconditions. + let (b, _) = unsafe { nearest_centroid(&p, 0.123, ¢roids, k, D) }; + assert_eq!(a, b, "inv_norm must not change the selected centroid"); + } + } + + #[test] + fn cluster_mixed_zero_norm_rows() { + // Some all-zero rows exercise the inv_norm == 0 path in accumulation + // and the squared_chord_distance == 0 return. + let n = 120; + let mut data = vec![0.0_f32; n * 8]; + let mut rng = Xoshiro256PlusPlus::seed_from_u64(7); + for (i, row) in data.chunks_exact_mut(8).enumerate() { + if i % 10 == 0 { + continue; // leave all-zero + } + for v in row.iter_mut() { + *v = rng.random_range(-1.0..1.0); + } + } + let result = cluster(&data, dim(8), &Config::for_k_with_seed(5, 1)); + + assert!(result.centroids.iter().all(|v| v.is_finite())); + assert!(result.labels.iter().all(|&l| l < 5)); + } +} diff --git a/libs/@local/graph/store/src/embedding/dimension.rs b/libs/@local/graph/store/src/embedding/dimension.rs new file mode 100644 index 00000000000..0ba85516ee3 --- /dev/null +++ b/libs/@local/graph/store/src/embedding/dimension.rs @@ -0,0 +1,85 @@ +use core::num::NonZero; + +/// An embedding vector dimension, guaranteed to be a positive multiple of 8. +/// +/// The multiple-of-8 invariant ensures that the dimension evenly divides into +/// SIMD lanes (8×f32 = `f32x8`), so vectorized kernels can operate without +/// remainder handling. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Dimension(NonZero); + +impl Dimension { + /// Creates a new dimension if `value` is non-zero and a multiple of 8. + /// + /// Returns [`None`] otherwise. + #[must_use] + pub const fn new(value: u16) -> Option { + // not using `?` here because it isn't `const` + let Some(value) = NonZero::new(value) else { + return None; + }; + + if !value.get().is_multiple_of(8) { + return None; + } + + Some(Self(value)) + } + + /// The raw dimension value. + #[must_use] + pub const fn get(self) -> u16 { + self.0.get() + } + + /// The raw dimension value as a [`NonZero`]. + #[must_use] + pub const fn value(self) -> NonZero { + self.0 + } +} + +pub const D128: Dimension = Dimension(NonZero::new(128).unwrap()); +pub const D256: Dimension = Dimension(NonZero::new(256).unwrap()); +pub const D512: Dimension = Dimension(NonZero::new(512).unwrap()); +pub const D1536: Dimension = Dimension(NonZero::new(1536).unwrap()); +pub const D3072: Dimension = Dimension(NonZero::new(3072).unwrap()); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn valid_multiples_of_8() { + for v in [8, 16, 24, 128, 256, 3072] { + assert!( + Dimension::new(v).is_some(), + "{v} should be a valid dimension" + ); + } + } + + #[test] + fn zero_rejected() { + assert!(Dimension::new(0).is_none()); + } + + #[test] + fn non_multiples_of_8_rejected() { + for v in [1, 2, 3, 4, 5, 6, 7, 9, 10, 15, 17, 100, 3071] { + assert!( + Dimension::new(v).is_none(), + "{v} should not be a valid dimension" + ); + } + } + + #[test] + fn constants_have_correct_values() { + assert_eq!(D128.0.get(), 128); + assert_eq!(D256.0.get(), 256); + assert_eq!(D512.0.get(), 512); + assert_eq!(D1536.0.get(), 1536); + assert_eq!(D3072.0.get(), 3072); + } +} diff --git a/libs/@local/graph/store/src/embedding/kernel.rs b/libs/@local/graph/store/src/embedding/kernel.rs new file mode 100644 index 00000000000..67f99a66831 --- /dev/null +++ b/libs/@local/graph/store/src/embedding/kernel.rs @@ -0,0 +1,780 @@ +#![expect( + clippy::inline_always, + reason = "while usually discouraged, SIMD operations need to be inlined, as otherwise we \ + spill SIMD registers, see the SIMD documentation." +)] +use core::{ + num::NonZero, + simd::{Simd, f32x8, num::SimdFloat as _}, +}; + +/// Fused multiply-add when the target has native FMA, separate mul+add otherwise. +/// +/// On `aarch64`, FMA is part of the base NEON instruction set (`fmla`). +/// On `x86_64`, FMA requires the `fma` target feature (`vfmadd`); without it, +/// `StdFloat::mul_add` falls back to a per-lane `fmaf` libc call which +/// destroys throughput. The non-FMA path uses a plain multiply and add +/// (`vmulps` + `vaddps`) instead. +#[inline(always)] +#[cfg(not(any(target_arch = "aarch64", target_feature = "fma")))] +fn simd_mul_add(lhs: f32x8, rhs: f32x8, acc: f32x8) -> f32x8 { + lhs * rhs + acc +} + +/// See non-FMA variant above for rationale. +#[inline(always)] +#[cfg(any(target_arch = "aarch64", target_feature = "fma"))] +fn simd_mul_add(lhs: f32x8, rhs: f32x8, acc: f32x8) -> f32x8 { + use std::simd::StdFloat as _; + + lhs.mul_add(rhs, acc) +} + +/// Computes the dot product of two equal-length `f32` slices using SIMD. +/// +/// Four independent accumulators are interleaved to saturate FMA throughput: +/// each accumulator feeds a separate dependency chain, hiding the 4-cycle +/// latency of `fmla`/`vfmadd` on typical micro-architectures. +/// +/// # Safety +/// +/// * `lhs.len() == rhs.len()` +/// * Both lengths are multiples of 8. +#[inline] +#[must_use] +pub unsafe fn dot(lhs: &[f32], rhs: &[f32]) -> f32 { + debug_assert!(lhs.len().is_multiple_of(8) && lhs.len() == rhs.len()); + + // SAFETY: the caller guarantees equal lengths and a multiple of 8. + // These hints let the compiler elide bounds checks in `as_chunks` and + // the subsequent indexing without raw pointer arithmetic. + unsafe { + core::hint::assert_unchecked(lhs.len() == rhs.len()); + core::hint::assert_unchecked(lhs.len().is_multiple_of(8)); + } + + let (lhs, _) = lhs.as_chunks::<8>(); + let (rhs, _) = rhs.as_chunks::<8>(); + + // SAFETY: both original slices have the same length and that length is a + // multiple of 8, so `as_chunks::<8>` produces equal-length chunk slices + // with empty remainders. + unsafe { + core::hint::assert_unchecked(lhs.len() == rhs.len()); + } + + let mut s0 = f32x8::splat(0.0); + let mut s1 = f32x8::splat(0.0); + let mut s2 = f32x8::splat(0.0); + let mut s3 = f32x8::splat(0.0); + + // Unrolled loop: process 4 chunks (32 floats) per iteration. + let mut offset = 0; + while offset + 4 <= lhs.len() { + let Some([l0, l1, l2, l3]) = lhs[offset..offset + 4].as_array() else { + unreachable!() + }; + let Some([r0, r1, r2, r3]) = rhs[offset..offset + 4].as_array() else { + unreachable!() + }; + + s0 = simd_mul_add(Simd::from_slice(l0), Simd::from_slice(r0), s0); + s1 = simd_mul_add(Simd::from_slice(l1), Simd::from_slice(r1), s1); + s2 = simd_mul_add(Simd::from_slice(l2), Simd::from_slice(r2), s2); + s3 = simd_mul_add(Simd::from_slice(l3), Simd::from_slice(r3), s3); + + offset += 4; + } + + // Tail: process remaining 0..3 chunks one at a time. + #[expect(clippy::min_ident_chars)] + while offset < lhs.len() { + let l = &lhs[offset]; + let r = &rhs[offset]; + + s0 = simd_mul_add(Simd::from_slice(l), Simd::from_slice(r), s0); + offset += 1; + } + + (s0 + s1 + s2 + s3).reduce_sum() +} + +/// Scales `value` in-place by `factor`. +/// +/// # Safety +/// +/// * `value.len()` is a multiple of 8. +#[inline] +pub(crate) unsafe fn scale(value: &mut [f32], factor: f32) { + debug_assert!(value.len().is_multiple_of(8)); + + // SAFETY: the caller guarantees a multiple of 8. + unsafe { + core::hint::assert_unchecked(value.len().is_multiple_of(8)); + } + + let factor = f32x8::splat(factor); + let (dst, _) = value.as_chunks_mut::<8>(); + + for dst in dst { + *dst = (f32x8::from_slice(dst) * factor).to_array(); + } +} + +/// Accumulates `src * factor` element-wise into `dst` (`dst += src * factor`). +/// +/// Fuses a scale and add into a single pass, using FMA where available. +/// Avoids the need for a scratch buffer when accumulating normalized vectors. +/// +/// # Safety +/// +/// * `dst.len() == src.len()` +/// * Both lengths are multiples of 8. +#[inline] +pub unsafe fn add_scaled_into(dst: &mut [f32], src: &[f32], factor: f32) { + debug_assert!(dst.len().is_multiple_of(8) && dst.len() == src.len()); + + // SAFETY: the caller guarantees equal lengths and a multiple of 8. + unsafe { + core::hint::assert_unchecked(dst.len() == src.len()); + core::hint::assert_unchecked(dst.len().is_multiple_of(8)); + } + + let factor = f32x8::splat(factor); + let (dst, _) = dst.as_chunks_mut::<8>(); + let (src, _) = src.as_chunks::<8>(); + + // SAFETY: same reasoning as the pre-chunk hints: equal input lengths + // that are multiples of 8 produce equal chunk counts. + unsafe { core::hint::assert_unchecked(dst.len() == src.len()) } + + for index in 0..dst.len() { + let acc = f32x8::from_slice(&dst[index]); + let val = f32x8::from_slice(&src[index]); + dst[index] = simd_mul_add(val, factor, acc).to_array(); + } +} + +/// Normalizes `value` to unit length in-place. +/// +/// If the vector has zero norm, it is left unchanged. +/// +/// # Safety +/// +/// * `value.len()` is a multiple of 8. +#[inline] +pub(crate) unsafe fn normalize(value: &mut [f32]) { + // SAFETY: `dot` requires equal lengths (trivially true, same slice) + // and a multiple of 8 (guaranteed by the caller). + let norm = unsafe { dot(value, value).sqrt() }; + + if norm > 0.0 { + let factor = 1.0 / norm; + // SAFETY: same slice, same length guarantee. + unsafe { + scale(value, factor); + } + } +} + +/// 4 points x 2 centroids. Eight independent accumulators give ILP 8 (enough to +/// saturate FMA throughput); each point chunk feeds 2 FMAs and each centroid +/// chunk feeds 4. Returns `dot[point][centroid]`. +/// +/// Register budget: 8 `f32x8` accumulators. On AVX2 (16 ymm) this leaves room +/// for the 6 operand loads. On NEON each `f32x8` is two 128-bit regs, so the 8 +/// accumulators take 16 of 32 registers; a 4x4 tile (16 accumulators) also fits +/// there if you want more centroid reuse. Either way, check the asm shows the +/// accumulators staying in registers with no stack spills, and that +/// `simd_mul_add` lowered to `vfmadd`/`fmla` and not a `fmaf` call. If the array +/// form ever spills, the manual unroll below is what keeps them in registers. +/// +/// # Safety +/// * all six slices have length `d` +/// * `d` is a multiple of 8 +#[inline(always)] // micro-kernel must inline nearest4 to keep accumulators in registers +#[must_use] +pub unsafe fn micro_4x2( + p0: &[f32], + p1: &[f32], + p2: &[f32], + p3: &[f32], + c0: &[f32], + c1: &[f32], +) -> [[f32; 2]; 4] { + debug_assert!(p0.len().is_multiple_of(8)); + debug_assert!( + [p1.len(), p2.len(), p3.len(), c0.len(), c1.len()] + .iter() + .all(|&l| l == p0.len()) + ); + + let (p0, _) = p0.as_chunks::<8>(); + let (p1, _) = p1.as_chunks::<8>(); + let (p2, _) = p2.as_chunks::<8>(); + let (p3, _) = p3.as_chunks::<8>(); + let (c0, _) = c0.as_chunks::<8>(); + let (c1, _) = c1.as_chunks::<8>(); + + // SAFETY: the caller guarantees all six slices have equal length `d`, + // and `d` is a multiple of 8. The hints let the compiler prove that + // `as_chunks` produces equal-length chunk slices. + unsafe { + core::hint::assert_unchecked(p0.len() == p1.len()); + core::hint::assert_unchecked(p0.len() == p2.len()); + core::hint::assert_unchecked(p0.len() == p3.len()); + core::hint::assert_unchecked(p0.len() == c0.len()); + core::hint::assert_unchecked(p0.len() == c1.len()); + } + + let mut a00 = f32x8::splat(0.0); + let mut a01 = f32x8::splat(0.0); + let mut a10 = f32x8::splat(0.0); + let mut a11 = f32x8::splat(0.0); + let mut a20 = f32x8::splat(0.0); + let mut a21 = f32x8::splat(0.0); + let mut a30 = f32x8::splat(0.0); + let mut a31 = f32x8::splat(0.0); + + for t in 0..c0.len() { + let v0 = Simd::from_array(c0[t]); + let v1 = Simd::from_array(c1[t]); + let x0 = Simd::from_array(p0[t]); + let x1 = Simd::from_array(p1[t]); + let x2 = Simd::from_array(p2[t]); + let x3 = Simd::from_array(p3[t]); + + // super::simd_mul_add picks the FMA arm per target. + a00 = simd_mul_add(x0, v0, a00); + a01 = simd_mul_add(x0, v1, a01); + a10 = simd_mul_add(x1, v0, a10); + a11 = simd_mul_add(x1, v1, a11); + a20 = simd_mul_add(x2, v0, a20); + a21 = simd_mul_add(x2, v1, a21); + a30 = simd_mul_add(x3, v0, a30); + a31 = simd_mul_add(x3, v1, a31); + } + + [ + [a00.reduce_sum(), a01.reduce_sum()], + [a10.reduce_sum(), a11.reduce_sum()], + [a20.reduce_sum(), a21.reduce_sum()], + [a30.reduce_sum(), a31.reduce_sum()], + ] +} + +/// 4 points x 1 centroid: the odd-`k` remainder of [`nearest4`]. Four +/// independent accumulators (one per point) share each centroid load, so the +/// centroid streams through registers once instead of once per point. +/// +/// # Safety +/// +/// * all five slices have length `d` +/// * `d` is a multiple of 8 +#[inline(always)] // micro-kernel must inline into nearest4 to keep accumulators in registers +pub(crate) unsafe fn micro_4x1( + p0: &[f32], + p1: &[f32], + p2: &[f32], + p3: &[f32], + c: &[f32], +) -> [f32; 4] { + debug_assert!(p0.len().is_multiple_of(8)); + debug_assert!( + [p1.len(), p2.len(), p3.len(), c.len()] + .iter() + .all(|&l| l == p0.len()) + ); + + let (p0, _) = p0.as_chunks::<8>(); + let (p1, _) = p1.as_chunks::<8>(); + let (p2, _) = p2.as_chunks::<8>(); + let (p3, _) = p3.as_chunks::<8>(); + let (c, _) = c.as_chunks::<8>(); + + // SAFETY: the caller guarantees all five slices have equal length `d`, + // and `d` is a multiple of 8. The hints let the compiler prove that + // `as_chunks` produces equal-length chunk slices. + unsafe { + core::hint::assert_unchecked(p0.len() == p1.len()); + core::hint::assert_unchecked(p0.len() == p2.len()); + core::hint::assert_unchecked(p0.len() == p3.len()); + core::hint::assert_unchecked(p0.len() == c.len()); + } + + let mut a0 = f32x8::splat(0.0); + let mut a1 = f32x8::splat(0.0); + let mut a2 = f32x8::splat(0.0); + let mut a3 = f32x8::splat(0.0); + + for t in 0..c.len() { + let v = Simd::from_array(c[t]); + + a0 = simd_mul_add(Simd::from_array(p0[t]), v, a0); + a1 = simd_mul_add(Simd::from_array(p1[t]), v, a1); + a2 = simd_mul_add(Simd::from_array(p2[t]), v, a2); + a3 = simd_mul_add(Simd::from_array(p3[t]), v, a3); + } + + [ + a0.reduce_sum(), + a1.reduce_sum(), + a2.reduce_sum(), + a3.reduce_sum(), + ] +} + +/// Finds the nearest centroid for 4 points simultaneously using the +/// [`micro_4x2`] tiled kernel. +/// +/// Returns `(centroid_index, raw_dot_product)` for each of the 4 points. +/// The raw dot product is **not** a distance; and must be converted via +/// using the chord distance formula. +/// +/// # Safety +/// +/// * All four point slices have length `d`. +/// * `centroids.len() >= k * d`. +/// * `d` is a multiple of 8. +#[inline] +#[must_use] +pub unsafe fn nearest4( + p0: &[f32], + p1: &[f32], + p2: &[f32], + p3: &[f32], + centroids: &[f32], + k: NonZero, + d: usize, +) -> [(u16, f32); 4] { + let mut best_dot = [f32::NEG_INFINITY; 4]; + let mut best_idx = [0_u16; 4]; + + // SAFETY: the caller guarantees these preconditions. + unsafe { + core::hint::assert_unchecked(p0.len() == d); + core::hint::assert_unchecked(p0.len() == p1.len()); + core::hint::assert_unchecked(p0.len() == p2.len()); + core::hint::assert_unchecked(p0.len() == p3.len()); + core::hint::assert_unchecked(centroids.len() >= k.get() * d); + core::hint::assert_unchecked(d.is_multiple_of(8)); + } + + let mut j = 0; + while j + 2 <= k.get() { + // SAFETY: `j + 2 <= k` and `centroids.len() >= k * d`, so both + // slices `[j*d .. (j+2)*d]` are in-bounds. + let c0 = unsafe { centroids.get_unchecked(j * d..j * d + d) }; + // SAFETY: see above. + let c1 = unsafe { centroids.get_unchecked((j + 1) * d..(j + 1) * d + d) }; + + // SAFETY: all six slices have length `d`, a multiple of 8. + let dots = unsafe { micro_4x2(p0, p1, p2, p3, c0, c1) }; + + #[expect( + clippy::cast_possible_truncation, + reason = "k originates from Config::k (u16), so j < k fits in u16" + )] + for m in 0..4 { + if dots[m][0] > best_dot[m] { + best_dot[m] = dots[m][0]; + best_idx[m] = j as u16; + } + if dots[m][1] > best_dot[m] { + best_dot[m] = dots[m][1]; + best_idx[m] = (j + 1) as u16; + } + } + j += 2; + } + + // Handle odd k: one remaining centroid via the 4x1 tile. + if j < k.get() { + let c = ¢roids[j * d..j * d + d]; + // SAFETY: all five slices have length `d`, a multiple of 8. + let dots = unsafe { micro_4x1(p0, p1, p2, p3, c) }; + + #[expect( + clippy::cast_possible_truncation, + reason = "k originates from Config::k (u16)" + )] + for m in 0..4 { + if dots[m] > best_dot[m] { + best_dot[m] = dots[m]; + best_idx[m] = j as u16; + } + } + } + + [ + (best_idx[0], best_dot[0]), + (best_idx[1], best_dot[1]), + (best_idx[2], best_dot[2]), + (best_idx[3], best_dot[3]), + ] +} + +#[cfg(test)] +mod tests { + #![expect(clippy::float_cmp, clippy::integer_division_remainder_used)] + + use super::*; + + macro_rules! nz { + ($expr:expr) => { + const { NonZero::new($expr).unwrap() } + }; + } + + /// Scalar dot product for reference. + fn ref_dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(x, y)| x * y).sum() + } + + /// Scalar normalize for reference. + fn ref_normalize(v: &mut [f32]) { + let norm = v.iter().map(|x| x * x).sum::().sqrt(); + if norm > 0.0 { + for x in v { + *x /= norm; + } + } + } + + /// Deterministic test vector: entry `i` gets `(i+1) * scale`. + #[expect(clippy::cast_precision_loss)] + fn ramp(len: usize, factor: f32) -> Vec { + (0..len).map(|i| (i + 1) as f32 * factor).collect() + } + + /// Asserts two f32 values are within relative tolerance, with an absolute + /// floor for values near zero. + fn assert_close(a: f32, b: f32, tol: f32) { + let diff = (a - b).abs(); + let denom = a.abs().max(b.abs()).max(1e-12); + assert!( + diff / denom < tol, + "values differ: {a} vs {b} (diff={diff}, rel={})", + diff / denom + ); + } + + #[test] + fn dot_matches_scalar_d8() { + let a = ramp(8, 1.0); + let b = ramp(8, 0.5); + let expected = ref_dot(&a, &b); + // SAFETY: both slices have length 8, a multiple of 8. + let got = unsafe { dot(&a, &b) }; + assert_close(got, expected, 1e-6); + } + + #[test] + fn dot_matches_scalar_d24() { + // 24 = 3 chunks of 8: the 4-unrolled body runs 0 iterations, + // all 3 chunks go through the tail path. + let a = ramp(24, 0.1); + let b = ramp(24, -0.2); + let expected = ref_dot(&a, &b); + // SAFETY: both slices have length 24, a multiple of 8. + let got = unsafe { dot(&a, &b) }; + assert_close(got, expected, 1e-5); + } + + #[test] + fn dot_matches_scalar_d3072() { + let a = ramp(3072, 0.001); + let b = ramp(3072, -0.002); + let expected = ref_dot(&a, &b); + // SAFETY: both slices have length 3072, a multiple of 8. + let got = unsafe { dot(&a, &b) }; + assert_close(got, expected, 1e-4); + } + + #[test] + fn dot_is_commutative() { + let a = ramp(32, 0.3); + let b = ramp(32, -0.7); + // SAFETY: both slices have length 32, a multiple of 8. + let ab = unsafe { dot(&a, &b) }; + // SAFETY: same slices, reversed. + let ba = unsafe { dot(&b, &a) }; + assert_eq!(ab, ba); + } + + #[test] + fn dot_self_is_squared_norm() { + let a = ramp(16, 0.5); + let expected: f32 = a.iter().map(|x| x * x).sum(); + // SAFETY: both arguments are the same 16-element slice. + let got = unsafe { dot(&a, &a) }; + assert_close(got, expected, 1e-6); + } + + #[test] + fn dot_orthogonal_is_zero() { + let mut a = vec![0.0_f32; 8]; + let mut b = vec![0.0_f32; 8]; + a[0] = 1.0; + b[1] = 1.0; + // SAFETY: both slices have length 8. + let got = unsafe { dot(&a, &b) }; + assert_eq!(got, 0.0); + } + + #[test] + fn scale_matches_scalar() { + let mut v = ramp(16, 1.0); + let factor = -0.5; + let expected: Vec = v.iter().map(|x| x * factor).collect(); + // SAFETY: slice has length 16, a multiple of 8. + unsafe { scale(&mut v, factor) } + assert_eq!(v, expected); + } + + #[test] + fn add_scaled_into_matches_separate_ops() { + let src = ramp(16, 1.0); + let factor = 0.3; + let mut dst = ramp(16, 0.5); + let expected: Vec = dst.iter().zip(&src).map(|(d, s)| d + s * factor).collect(); + + // SAFETY: both slices have length 16, a multiple of 8. + unsafe { add_scaled_into(&mut dst, &src, factor) } + + for (&got, &exp) in dst.iter().zip(&expected) { + assert_close(got, exp, 1e-6); + } + } + + #[test] + fn add_scaled_into_factor_zero_is_identity() { + let src = ramp(8, 100.0); + let mut dst = ramp(8, 1.0); + let original = dst.clone(); + // SAFETY: both slices have length 8, a multiple of 8. + unsafe { add_scaled_into(&mut dst, &src, 0.0) } + assert_eq!(dst, original); + } + + #[test] + fn normalize_produces_unit_norm() { + let mut v = ramp(32, 0.7); + // SAFETY: length 32, a multiple of 8. + unsafe { normalize(&mut v) } + // SAFETY: same slice, length unchanged. + let norm = unsafe { dot(&v, &v).sqrt() }; + assert_close(norm, 1.0, 1e-6); + } + + #[test] + fn normalize_preserves_direction() { + let mut v = ramp(16, 2.0); + let mut ref_v = v.clone(); + ref_normalize(&mut ref_v); + // SAFETY: length 16, a multiple of 8. + unsafe { normalize(&mut v) } + for (&a, &b) in v.iter().zip(&ref_v) { + assert_close(a, b, 1e-6); + } + } + + #[test] + fn normalize_zero_vector_unchanged() { + let mut v = vec![0.0_f32; 8]; + // SAFETY: length 8, a multiple of 8. + unsafe { normalize(&mut v) } + assert!(v.iter().all(|&x| x == 0.0)); + } + + #[test] + fn normalize_already_unit_is_stable() { + let mut v = vec![0.0_f32; 8]; + v[0] = 1.0; + // SAFETY: length 8, a multiple of 8. + unsafe { normalize(&mut v) } + assert_close(v[0], 1.0, 1e-7); + assert!(v[1..].iter().all(|&x| x == 0.0)); + } + + #[test] + fn micro_4x2_matches_individual_dots() { + let d = 16; + let p0 = ramp(d, 0.1); + let p1 = ramp(d, -0.2); + let p2 = ramp(d, 0.3); + let p3 = ramp(d, -0.4); + let c0 = ramp(d, 0.5); + let c1 = ramp(d, -0.6); + + // SAFETY: all 6 slices have length 16, a multiple of 8. + let got = unsafe { micro_4x2(&p0, &p1, &p2, &p3, &c0, &c1) }; + + let expected = [ + [ref_dot(&p0, &c0), ref_dot(&p0, &c1)], + [ref_dot(&p1, &c0), ref_dot(&p1, &c1)], + [ref_dot(&p2, &c0), ref_dot(&p2, &c1)], + [ref_dot(&p3, &c0), ref_dot(&p3, &c1)], + ]; + + for (g, e) in got.iter().zip(&expected) { + assert_close(g[0], e[0], 1e-5); + assert_close(g[1], e[1], 1e-5); + } + } + + #[test] + fn micro_4x2_d3072() { + let d = 3072; + let p0 = ramp(d, 0.001); + let p1 = ramp(d, -0.001); + let p2 = ramp(d, 0.002); + let p3 = ramp(d, -0.002); + let c0 = ramp(d, 0.001); + let c1 = ramp(d, -0.001); + + // SAFETY: all 6 slices have length 3072, a multiple of 8. + let got = unsafe { micro_4x2(&p0, &p1, &p2, &p3, &c0, &c1) }; + + let expected = [ + [ref_dot(&p0, &c0), ref_dot(&p0, &c1)], + [ref_dot(&p1, &c0), ref_dot(&p1, &c1)], + [ref_dot(&p2, &c0), ref_dot(&p2, &c1)], + [ref_dot(&p3, &c0), ref_dot(&p3, &c1)], + ]; + + for (g, e) in got.iter().zip(&expected) { + assert_close(g[0], e[0], 1e-3); + assert_close(g[1], e[1], 1e-3); + } + } + + #[test] + fn micro_4x1_matches_individual_dots() { + let d = 16; + let p0 = ramp(d, 0.1); + let p1 = ramp(d, -0.2); + let p2 = ramp(d, 0.3); + let p3 = ramp(d, -0.4); + let c = ramp(d, 0.5); + + // SAFETY: all 5 slices have length 16, a multiple of 8. + let got = unsafe { micro_4x1(&p0, &p1, &p2, &p3, &c) }; + + let expected = [ + ref_dot(&p0, &c), + ref_dot(&p1, &c), + ref_dot(&p2, &c), + ref_dot(&p3, &c), + ]; + + for (g, e) in got.iter().zip(&expected) { + assert_close(*g, *e, 1e-5); + } + } + + #[test] + fn micro_4x1_d3072() { + let d = 3072; + let p0 = ramp(d, 0.001); + let p1 = ramp(d, -0.001); + let p2 = ramp(d, 0.002); + let p3 = ramp(d, -0.002); + let c = ramp(d, 0.001); + + // SAFETY: all 5 slices have length 3072, a multiple of 8. + let got = unsafe { micro_4x1(&p0, &p1, &p2, &p3, &c) }; + + let expected = [ + ref_dot(&p0, &c), + ref_dot(&p1, &c), + ref_dot(&p2, &c), + ref_dot(&p3, &c), + ]; + + for (g, e) in got.iter().zip(&expected) { + assert_close(*g, *e, 1e-3); + } + } + + #[test] + fn nearest4_matches_brute_force_even_k() { + let d = 8; + let k = nz!(4); + + // 4 centroids: axis-aligned unit vectors. + let mut centroids = vec![0.0_f32; k.get() * d]; + for i in 0..k.get() { + centroids[i * d + i] = 1.0; + } + + // 4 points, each close to a different centroid. + let mut points: [Vec; 4] = core::array::from_fn(|_| vec![0.0_f32; d]); + for i in 0..4 { + points[i][i] = 10.0; + points[i][(i + 1) % d] = 0.1; + } + + // SAFETY: d=8 (multiple of 8), k=4 > 0, centroids has length k*d, + // all point slices have length d. + let got = unsafe { + nearest4( + &points[0], &points[1], &points[2], &points[3], ¢roids, k, d, + ) + }; + + assert_eq!(got[0].0, 0); + assert_eq!(got[1].0, 1); + assert_eq!(got[2].0, 2); + assert_eq!(got[3].0, 3); + } + + #[test] + fn nearest4_matches_brute_force_odd_k() { + let d = 8; + let k = nz!(3); // odd: exercises the remainder path + + let mut centroids = vec![0.0_f32; k.get() * d]; + for i in 0..k.get() { + centroids[i * d + i] = 1.0; + } + + let mut points: [Vec; 4] = core::array::from_fn(|_| vec![0.0_f32; d]); + points[0][0] = 5.0; + points[1][1] = 5.0; + points[2][2] = 5.0; + points[3][0] = 3.0; // closest to centroid 0 + + // SAFETY: d=8 (multiple of 8), k=3 > 0, centroids has length k*d, + // all point slices have length d. + let got = unsafe { + nearest4( + &points[0], &points[1], &points[2], &points[3], ¢roids, k, d, + ) + }; + + assert_eq!(got[0].0, 0); + assert_eq!(got[1].0, 1); + assert_eq!(got[2].0, 2); + assert_eq!(got[3].0, 0); + } + + #[test] + fn nearest4_k1_all_same() { + let d = 8; + let centroids = ramp(d, 1.0); + let p0 = ramp(d, 0.1); + let p1 = ramp(d, -0.2); + let p2 = ramp(d, 0.3); + let p3 = ramp(d, -0.4); + + // SAFETY: d=8 (multiple of 8), k=1 > 0, centroids has length d, + // all point slices have length d. + let got = unsafe { nearest4(&p0, &p1, &p2, &p3, ¢roids, nz!(1), d) }; + + assert_eq!(got[0].0, 0); + assert_eq!(got[1].0, 0); + assert_eq!(got[2].0, 0); + assert_eq!(got[3].0, 0); + } +} diff --git a/libs/@local/graph/store/src/embedding/mod.rs b/libs/@local/graph/store/src/embedding/mod.rs new file mode 100644 index 00000000000..c370176aacc --- /dev/null +++ b/libs/@local/graph/store/src/embedding/mod.rs @@ -0,0 +1,16 @@ +#![expect( + unsafe_code, + clippy::indexing_slicing, + clippy::float_arithmetic, + clippy::min_ident_chars, + clippy::many_single_char_names, + reason = "embedding module is under active development. Single-char idents (k, n, m, d, x) \ + are standard mathematical notation for clustering." +)] + +pub mod clustering; +pub mod dimension; +// Hidden from docs: the kernel is an implementation detail, exposed only so +// the `embedding` bench target can measure it in isolation. +#[doc(hidden)] +pub mod kernel; diff --git a/libs/@local/graph/store/src/entity/mod.rs b/libs/@local/graph/store/src/entity/mod.rs index f1174879703..e9aa8161500 100644 --- a/libs/@local/graph/store/src/entity/mod.rs +++ b/libs/@local/graph/store/src/entity/mod.rs @@ -4,14 +4,15 @@ pub use self::{ EntityQuerySortingToken, EntityQueryToken, }, store::{ - ClosedMultiEntityTypeMap, CreateEntityParams, DeleteEntitiesParams, DeletionScope, - DeletionSummary, DiffEntityParams, DiffEntityResult, EntityPermissions, EntityStore, - EntityValidationType, HasPermissionForEntitiesParams, LinkDeletionBehavior, - PatchEntityParams, QueryConversion, QueryEntitiesParams, QueryEntitiesResponse, - QueryEntitySubgraphParams, QueryEntitySubgraphResponse, SearchEntitiesFilter, - SearchEntitiesParams, SearchEntitiesResponse, SummarizeEntitiesParams, - SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityComponents, - ValidateEntityError, ValidateEntityParams, + ClosedMultiEntityTypeMap, ClusterEntitiesParams, ClusterEntitiesResponse, + CreateEntityParams, DeleteEntitiesParams, DeletionScope, DeletionSummary, DiffEntityParams, + DiffEntityResult, EntityCluster, EntityPermissions, EntityStore, EntityValidationType, + HasPermissionForEntitiesParams, LinkDeletionBehavior, PatchEntityParams, QueryConversion, + QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, + QueryEntitySubgraphResponse, SearchEntitiesFilter, SearchEntitiesParams, + SearchEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, + UpdateEntityEmbeddingsParams, ValidateEntityComponents, ValidateEntityError, + ValidateEntityParams, }, validation_report::{ EmptyEntityTypes, EntityRetrieval, EntityTypeRetrieval, EntityTypesError, diff --git a/libs/@local/graph/store/src/entity/store.rs b/libs/@local/graph/store/src/entity/store.rs index 94ff152a36b..8bf5610f40b 100644 --- a/libs/@local/graph/store/src/entity/store.rs +++ b/libs/@local/graph/store/src/entity/store.rs @@ -1,4 +1,5 @@ use alloc::borrow::Cow; +use core::num::NonZero; use std::collections::{HashMap, HashSet}; use error_stack::Report; @@ -36,7 +37,9 @@ use utoipa::{ use crate::{ entity::{EntityQueryCursor, EntityQuerySorting, EntityValidationReport}, entity_type::{EntityTypeResolveDefinitions, IncludeEntityTypeOption}, - error::{CheckPermissionError, DeletionError, InsertionError, QueryError, UpdateError}, + error::{ + CheckPermissionError, ClusterError, DeletionError, InsertionError, QueryError, UpdateError, + }, filter::{Filter, SemanticDistance}, subgraph::{ Subgraph, @@ -525,6 +528,61 @@ impl PatchEntityParams { } } +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct ClusterEntitiesParams { + pub entity_ids: Vec, + /// Desired number of clusters. Clamped to the number of entities with + /// embeddings when that is smaller. + pub cluster_count: u16, + /// Embedding dimension after matryoshka truncation. Must be a positive + /// multiple of 8; values above 3072 are rejected. Defaults to 256. + #[serde(default = "ClusterEntitiesParams::default_dimension")] + #[cfg_attr(feature = "utoipa", schema(value_type = u16, minimum = 1, default = 256, example = 256))] + pub dimension: NonZero, + + /// Seed for the random number generator used in clustering. + /// + /// If not provided, a random seed will be used. + pub seed: Option, +} + +impl ClusterEntitiesParams { + const fn default_dimension() -> NonZero { + const { NonZero::new(256).unwrap() } + } +} + +/// One cluster from a spherical k-means run over entity embeddings. +#[derive(Debug, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] +pub struct EntityCluster { + /// Index in `0..cluster_count`. + pub cluster_id: u16, + pub entity_ids: Vec, + /// Unit-normalized centroid with length equal to the requested dimension. + pub centroid: Vec, +} + +/// Result of [`EntityStore::cluster_entities`]. +#[derive(Debug, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] +pub struct ClusterEntitiesResponse { + /// One entry per non-empty cluster. Empty clusters (no points assigned) + /// are omitted. + pub clusters: Vec, + /// Entities from the request that had no stored embedding. + pub missing_embeddings: Vec, + /// Sum of squared chord distances from every clustered entity to its + /// assigned centroid. Lower is tighter; comparable across runs over the + /// same entities, e.g. to choose a cluster count. `0.0` when nothing was + /// clustered. + pub inertia: f32, +} + #[derive(Debug, Deserialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] #[serde(rename_all = "camelCase", deny_unknown_fields)] @@ -912,6 +970,27 @@ pub trait EntityStore { params: UpdateEntityEmbeddingsParams<'_>, ) -> impl Future>> + Send; + /// Groups entities by embedding similarity using spherical k-means. + /// + /// Each entity's combined embedding is truncated to the requested + /// dimension (matryoshka encoding) before clustering. The returned + /// centroids are unit-normalized and have the same dimension. + /// + /// Entities without a stored embedding are not clustered; they appear + /// in [`ClusterEntitiesResponse::missing_embeddings`]. + /// + /// # Errors + /// + /// Returns [`ClusterError::InvalidDimension`] if the dimension is not a + /// positive multiple of 8, [`ClusterError::DimensionTooLarge`] if it + /// exceeds the stored embedding width, or [`ClusterError::Store`] if the + /// embedding query fails. + fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: ClusterEntitiesParams, + ) -> impl Future>> + Send; + /// Re-indexes the cache for entities. /// /// This is only needed if the entity was changed in place without an update procedure. This is diff --git a/libs/@local/graph/store/src/error.rs b/libs/@local/graph/store/src/error.rs index d9f39229103..3f1b684167e 100644 --- a/libs/@local/graph/store/src/error.rs +++ b/libs/@local/graph/store/src/error.rs @@ -1,4 +1,4 @@ -use core::{error::Error, fmt}; +use core::{error::Error, fmt, num::NonZero}; #[derive(Debug)] #[must_use] @@ -68,3 +68,18 @@ pub enum CheckPermissionError { } impl Error for CheckPermissionError {} + +/// Failure to cluster entities by embedding similarity. +#[derive(Debug, derive_more::Display)] +#[display("Could not cluster entities: {_variant}")] +#[must_use] +pub enum ClusterError { + #[display("dimension {dimension} is not a positive multiple of 8")] + InvalidDimension { dimension: NonZero }, + #[display("dimension {dimension} exceeds stored embedding dimension {max}")] + DimensionTooLarge { dimension: NonZero, max: u16 }, + #[display("embedding query failed")] + Store, +} + +impl Error for ClusterError {} diff --git a/libs/@local/graph/store/src/lib.rs b/libs/@local/graph/store/src/lib.rs index 2931d3f0eca..668c46b20b8 100644 --- a/libs/@local/graph/store/src/lib.rs +++ b/libs/@local/graph/store/src/lib.rs @@ -5,6 +5,10 @@ #![feature( // Language Features impl_trait_in_assoc_type, + + // Library Features, + portable_simd, + integer_widen_truncate, )] #![cfg_attr(test, feature( // Language Features @@ -23,6 +27,7 @@ pub mod oauth_provider; pub mod property_type; pub mod user_deletion; +pub mod embedding; pub mod error; pub mod filter; pub mod migration; diff --git a/libs/@local/graph/type-fetcher/src/store.rs b/libs/@local/graph/type-fetcher/src/store.rs index 99a3ad3415c..87a124b74b0 100644 --- a/libs/@local/graph/type-fetcher/src/store.rs +++ b/libs/@local/graph/type-fetcher/src/store.rs @@ -35,9 +35,9 @@ use hash_graph_store::{ UnarchiveDataTypeParams, UpdateDataTypeEmbeddingParams, UpdateDataTypesParams, }, entity::{ - CreateEntityParams, DeleteEntitiesParams, DeletionSummary, EntityStore, - EntityValidationReport, HasPermissionForEntitiesParams, PatchEntityParams, - QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, + ClusterEntitiesParams, ClusterEntitiesResponse, CreateEntityParams, DeleteEntitiesParams, + DeletionSummary, EntityStore, EntityValidationReport, HasPermissionForEntitiesParams, + PatchEntityParams, QueryEntitiesParams, QueryEntitiesResponse, QueryEntitySubgraphParams, QueryEntitySubgraphResponse, SearchEntitiesParams, SearchEntitiesResponse, SummarizeEntitiesParams, SummarizeEntitiesResponse, UpdateEntityEmbeddingsParams, ValidateEntityParams, @@ -50,7 +50,9 @@ use hash_graph_store::{ QueryEntityTypesResponse, SearchEntityTypesParams, SearchEntityTypesResponse, UnarchiveEntityTypeParams, UpdateEntityTypeEmbeddingParams, UpdateEntityTypesParams, }, - error::{CheckPermissionError, DeletionError, InsertionError, QueryError, UpdateError}, + error::{ + CheckPermissionError, ClusterError, DeletionError, InsertionError, QueryError, UpdateError, + }, filter::{Filter, QueryRecord}, pool::StorePool, property_type::{ @@ -1713,6 +1715,14 @@ where self.store.update_entity_embeddings(actor_id, params).await } + async fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: ClusterEntitiesParams, + ) -> Result> { + self.store.cluster_entities(actor_id, params).await + } + async fn reindex_entity_cache(&mut self) -> Result<(), Report> { self.store.reindex_entity_cache().await } diff --git a/tests/graph/integration/postgres/lib.rs b/tests/graph/integration/postgres/lib.rs index 5c2ab899e3f..f9e54f423bb 100644 --- a/tests/graph/integration/postgres/lib.rs +++ b/tests/graph/integration/postgres/lib.rs @@ -890,6 +890,17 @@ impl EntityStore for DatabaseApi<'_> { self.store.reindex_entity_cache().await } + async fn cluster_entities( + &self, + actor_id: ActorEntityUuid, + params: hash_graph_store::entity::ClusterEntitiesParams, + ) -> Result< + hash_graph_store::entity::ClusterEntitiesResponse, + Report, + > { + self.store.cluster_entities(actor_id, params).await + } + async fn has_permission_for_entities( &self, authenticated_actor: AuthenticatedActor,