diff --git a/libs/@local/graph/api/openapi/openapi.json b/libs/@local/graph/api/openapi/openapi.json index 6fa9755b618..ba069fe91d5 100644 --- a/libs/@local/graph/api/openapi/openapi.json +++ b/libs/@local/graph/api/openapi/openapi.json @@ -1648,16 +1648,6 @@ "$ref": "#/components/schemas/ActorEntityUuid" } }, - { - "name": "Interactive", - "in": "header", - "description": "Whether the request is used interactively", - "required": false, - "schema": { - "type": "boolean", - "nullable": true - } - }, { "name": "after", "in": "query", @@ -1727,16 +1717,6 @@ "$ref": "#/components/schemas/ActorEntityUuid" } }, - { - "name": "Interactive", - "in": "header", - "description": "Whether the query is interactive", - "required": false, - "schema": { - "type": "boolean", - "nullable": true - } - }, { "name": "after", "in": "query", @@ -4686,60 +4666,6 @@ "type": "object" } }, - "EntityQueryOptions": { - "type": "object", - "required": [ - "temporalAxes", - "includeDrafts", - "includePermissions" - ], - "properties": { - "conversions": { - "type": "array", - "items": { - "$ref": "#/components/schemas/QueryConversion" - } - }, - "cursor": { - "allOf": [ - { - "$ref": "#/components/schemas/EntityQueryCursor" - } - ], - "nullable": true - }, - "includeDrafts": { - "type": "boolean" - }, - "includeEntityTypes": { - "allOf": [ - { - "$ref": "#/components/schemas/IncludeEntityTypeOption" - } - ], - "nullable": true - }, - "includePermissions": { - "type": "boolean" - }, - "limit": { - "type": "integer", - "nullable": true, - "minimum": 0 - }, - "sortingPaths": { - "type": "array", - "items": { - "$ref": "#/components/schemas/EntityQuerySortingRecord" - }, - "nullable": true - }, - "temporalAxes": { - "$ref": "#/components/schemas/QueryTemporalAxesUnresolved" - } - }, - "additionalProperties": false - }, "EntityQuerySortingPath": { "type": "array", "items": { @@ -8042,42 +7968,80 @@ } }, "QueryEntitiesRequest": { - "oneOf": [ - { + "type": "object", + "required": [ + "filter", + "temporalAxes", + "includeDrafts", + "includePermissions" + ], + "properties": { + "conversions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/QueryConversion" + } + }, + "cursor": { "allOf": [ { - "$ref": "#/components/schemas/EntityQueryOptions" - }, - { - "type": "object", - "required": [ - "query" - ], - "properties": { - "query": {} - } + "$ref": "#/components/schemas/EntityQueryCursor" } - ] + ], + "nullable": true }, - { + "filter": { + "$ref": "#/components/schemas/Filter" + }, + "includeCount": { + "type": "boolean" + }, + "includeCreatedByIds": { + "type": "boolean" + }, + "includeDrafts": { + "type": "boolean" + }, + "includeEditionCreatedByIds": { + "type": "boolean" + }, + "includeEntityTypes": { "allOf": [ { - "$ref": "#/components/schemas/EntityQueryOptions" - }, - { - "type": "object", - "required": [ - "filter" - ], - "properties": { - "filter": { - "$ref": "#/components/schemas/Filter" - } - } + "$ref": "#/components/schemas/IncludeEntityTypeOption" } - ] + ], + "nullable": true + }, + "includePermissions": { + "type": "boolean" + }, + "includeTypeIds": { + "type": "boolean" + }, + "includeTypeTitles": { + "type": "boolean" + }, + "includeWebIds": { + "type": "boolean" + }, + "limit": { + "type": "integer", + "nullable": true, + "minimum": 0 + }, + "sortingPaths": { + "type": "array", + "items": { + "$ref": "#/components/schemas/EntityQuerySortingRecord" + }, + "nullable": true + }, + "temporalAxes": { + "$ref": "#/components/schemas/QueryTemporalAxesUnresolved" } - ] + }, + "additionalProperties": false }, "QueryEntitiesResponse": { "type": "object", @@ -8125,12 +8089,11 @@ { "allOf": [ { - "$ref": "#/components/schemas/EntityQueryOptions" + "$ref": "#/components/schemas/QueryEntitiesRequest" }, { "type": "object", "required": [ - "query", "traversalPaths", "graphResolveDepths" ], @@ -8138,7 +8101,6 @@ "graphResolveDepths": { "$ref": "#/components/schemas/GraphResolveDepths" }, - "query": {}, "traversalPaths": { "type": "array", "items": { @@ -8152,7 +8114,7 @@ { "allOf": [ { - "$ref": "#/components/schemas/EntityQueryOptions" + "$ref": "#/components/schemas/QueryEntitiesRequest" }, { "type": "object", @@ -8209,13 +8171,9 @@ { "type": "object", "required": [ - "filter", "traversalPaths" ], "properties": { - "filter": { - "$ref": "#/components/schemas/Filter" - }, "traversalPaths": { "type": "array", "items": { diff --git a/libs/@local/graph/api/src/rest/entity.rs b/libs/@local/graph/api/src/rest/entity/mod.rs similarity index 59% rename from libs/@local/graph/api/src/rest/entity.rs rename to libs/@local/graph/api/src/rest/entity/mod.rs index dcc4846f143..17e12fb342e 100644 --- a/libs/@local/graph/api/src/rest/entity.rs +++ b/libs/@local/graph/api/src/rest/entity/mod.rs @@ -1,5 +1,7 @@ //! Web routes for CRU operations on entities. +pub mod query; + use alloc::sync::Arc; use std::collections::HashMap; @@ -10,18 +12,16 @@ use hash_graph_postgres_store::store::error::{EntityDoesNotExist, RaceConditionO use hash_graph_store::{ self, entity::{ - ClosedMultiEntityTypeMap, CreateEntityParams, DiffEntityParams, DiffEntityResult, - EntityPermissions, EntityQueryCursor, EntityQuerySortingRecord, EntityQuerySortingToken, - EntityQueryToken, EntityStore, EntityTypesError, EntityValidationReport, - EntityValidationType, HasPermissionForEntitiesParams, LinkDataStateError, - LinkDataValidationReport, LinkError, LinkTargetError, LinkValidationReport, - LinkedEntityError, MetadataValidationReport, PatchEntityParams, + ClosedMultiEntityTypeMap, CountEntitiesParams, CreateEntityParams, DiffEntityParams, + DiffEntityResult, EntityPermissions, EntityQueryCursor, EntityQuerySortingRecord, + EntityQuerySortingToken, EntityQueryToken, EntityStore, EntityTypesError, + EntityValidationReport, EntityValidationType, HasPermissionForEntitiesParams, + LinkDataStateError, LinkDataValidationReport, LinkError, LinkTargetError, + LinkValidationReport, LinkedEntityError, MetadataValidationReport, PatchEntityParams, PropertyMetadataValidationReport, QueryConversion, QueryEntitiesResponse, - SearchEntitiesFilter, SearchEntitiesResponse, SummarizeEntitiesParams, - SummarizeEntitiesResponse, UnexpectedEntityType, UpdateEntityEmbeddingsParams, - ValidateEntityComponents, ValidateEntityParams, + UnexpectedEntityType, UpdateEntityEmbeddingsParams, ValidateEntityComponents, + ValidateEntityParams, }, - entity_type::EntityTypeResolveDefinitions, pool::StorePool, query::{NullOrdering, Ordering}, }; @@ -41,9 +41,7 @@ use hash_graph_types::{ }, }; use hash_temporal_client::TemporalClient; -use hashql_core::heap::Heap; -use serde::{Deserialize as _, Serialize}; -use serde_json::value::RawValue as RawJsonvalue; +use serde::Deserialize as _; use type_system::{ knowledge::{ Confidence, Entity, Property, @@ -67,35 +65,30 @@ use type_system::{ }, value::{ValueMetadata, metadata::ValueProvenance}, }, - ontology::VersionedUrl, principal::actor::ActorType, provenance::{Location, OriginProvenance, SourceProvenance, SourceType}, }; -use utoipa::{OpenApi, ToSchema}; -pub use crate::rest::entity_query_request::{ - EntityQuery, EntityQueryOptions, QueryEntitiesRequest, QueryEntitySubgraphRequest, - SearchEntitiesRequest, +use self::query::{ + QueryEntitySubgraphResponse, count_entities, query_entities, query_entity_subgraph, + request::{QueryEntitiesRequest, QueryEntitySubgraphRequest}, }; use crate::rest::{ - ApiConfig, AuthenticatedUserHeader, InteractiveHeader, OpenApiQuery, QueryLogger, - entity_query_request::CompilationOptions, + AuthenticatedUserHeader, OpenApiQuery, QueryLogger, json::Json, status::{BoxedResponse, report_to_response}, - utoipa_typedef::subgraph::Subgraph, }; -#[derive(OpenApi)] +#[derive(utoipa::OpenApi)] #[openapi( paths( create_entity, create_entities, validate_entity, has_permission_for_entities, - query_entities, - query_entity_subgraph, - search_entities, - summarize_entities, + self::query::query_entities, + self::query::query_entity_subgraph, + self::query::count_entities, patch_entity, update_entity_embeddings, diff_entity, @@ -108,8 +101,7 @@ use crate::rest::{ PropertyArrayWithMetadata, PropertyObjectWithMetadata, ValidateEntityParams, - SummarizeEntitiesParams, - SummarizeEntitiesResponse, + CountEntitiesParams, EntityValidationType, ValidateEntityComponents, Embedding, @@ -122,12 +114,8 @@ use crate::rest::{ HasPermissionForEntitiesParams, - EntityQueryOptions, QueryEntitiesRequest, QueryEntitySubgraphRequest, - SearchEntitiesRequest, - SearchEntitiesFilter, - SearchEntitiesResponse, EntityQueryCursor, Ordering, NullOrdering, @@ -230,13 +218,12 @@ impl EntityResource { .route("/validate", post(validate_entity::)) .route("/embeddings", post(update_entity_embeddings::)) .route("/permissions", post(has_permission_for_entities::)) - .route("/search", post(search_entities::)) .nest( "/query", Router::new() .route("/", post(query_entities::)) .route("/subgraph", post(query_entity_subgraph::)) - .route("/summarize", post(summarize_entities::)), + .route("/count", post(count_entities::)), ), ) } @@ -408,290 +395,6 @@ where .map_err(report_to_response) } -#[utoipa::path( - post, - path = "/entities/query", - request_body = QueryEntitiesRequest, - tag = "Entity", - params( - ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), - ("Interactive" = Option, Header, description = "Whether the request is used interactively"), - ("after" = Option, Query, description = "The cursor to start reading from"), - ("limit" = Option, Query, description = "The maximum number of entities to read"), - ), - responses( - ( - status = 200, - content_type = "application/json", - body = QueryEntitiesResponse, - description = "A list of entities that satisfy the given query.", - ), - (status = 422, content_type = "text/plain", description = "Provided query is invalid"), - (status = 500, description = "Store error occurred"), - ) -)] -async fn query_entities( - AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, - InteractiveHeader(interactive): InteractiveHeader, - store_pool: Extension>, - temporal_client: Extension>>, - Extension(api_config): Extension, - mut query_logger: Option>, - Json(request): Json>, -) -> Result>, BoxedResponse> -where - S: StorePool + Send + Sync, -{ - if let Some(query_logger) = &mut query_logger { - query_logger.capture(actor_id, OpenApiQuery::GetEntities(&request)); - } - - let store = store_pool - .acquire(temporal_client.0) - .await - .map_err(report_to_response)?; - - let request = QueryEntitiesRequest::deserialize(&*request) - .map_err(Report::from) - .attach(hash_status::StatusCode::InvalidArgument) - .map_err(report_to_response)?; - - let (query, options) = request.into_parts(); - - // TODO: https://linear.app/hash/issue/H-5351/reuse-parts-between-compilation-units - let mut heap = Heap::uninitialized(); - - if matches!(query, EntityQuery::Query { .. }) { - // The heap is going to be used in the compilation of the query and therefore needs to be - // primed. - // Doing this in a separate step allows us to be allocation free when not using HashQL - // queries. - heap.prime(); - } - - let filter = query.compile(&heap, CompilationOptions { interactive })?; - - let params = options - .into_params(filter, api_config) - .attach(hash_status::StatusCode::InvalidArgument) - .map_err(report_to_response)?; - - let response = store - .query_entities(actor_id, params) - .await - .map(Json) - .map_err(report_to_response); - - if let Some(query_logger) = &mut query_logger { - query_logger.send().await.map_err(report_to_response)?; - } - response -} - -#[utoipa::path( - post, - path = "/entities/search", - request_body = SearchEntitiesRequest, - tag = "Entity", - params( - ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), - ), - responses( - ( - status = 200, - content_type = "application/json", - body = SearchEntitiesResponse, - description = "Entities ordered by ascending cosine distance to the query embedding.", - ), - (status = 400, content_type = "text/plain", description = "Provided request body is invalid"), - (status = 500, description = "Store error occurred"), - ) -)] -async fn search_entities( - AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, - store_pool: Extension>, - temporal_client: Extension>>, - Extension(api_config): Extension, - Json(request): Json, -) -> Result, BoxedResponse> -where - S: StorePool + Send + Sync, -{ - let store = store_pool - .acquire(temporal_client.0) - .await - .map_err(report_to_response)?; - - let params = request - .into_params(api_config) - .attach(hash_status::StatusCode::InvalidArgument) - .map_err(report_to_response)?; - - store - .search_entities(actor_id, params) - .await - .map(Json) - .map_err(report_to_response) -} - -#[derive(Serialize, ToSchema)] -#[serde(rename_all = "camelCase")] -struct QueryEntitySubgraphResponse<'r> { - subgraph: Subgraph, - #[serde(borrow)] - cursor: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - #[schema(nullable = false)] - closed_multi_entity_types: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - #[schema(nullable = false)] - definitions: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[schema(nullable = false)] - entity_permissions: Option>, -} - -#[utoipa::path( - post, - path = "/entities/query/subgraph", - request_body = QueryEntitySubgraphRequest, - tag = "Entity", - params( - ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), - ("Interactive" = Option, Header, description = "Whether the query is interactive"), - ("after" = Option, Query, description = "The cursor to start reading from"), - ("limit" = Option, Query, description = "The maximum number of entities to read"), - ), - responses( - ( - status = 200, - content_type = "application/json", - body = QueryEntitySubgraphResponse, - description = "A subgraph rooted at entities that satisfy the given query, each resolved to the requested depth.", - ), - (status = 422, content_type = "text/plain", description = "Provided query is invalid"), - (status = 500, description = "Store error occurred"), - ) -)] -async fn query_entity_subgraph( - AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, - InteractiveHeader(interactive): InteractiveHeader, - store_pool: Extension>, - temporal_client: Extension>>, - Extension(api_config): Extension, - mut query_logger: Option>, - Json(request): Json, -) -> Result>, BoxedResponse> -where - S: StorePool + Send + Sync, -{ - if let Some(query_logger) = &mut query_logger { - query_logger.capture(actor_id, OpenApiQuery::GetEntitySubgraph(&request)); - } - - let store = store_pool - .acquire(temporal_client.0) - .await - .map_err(report_to_response)?; - - let request = QueryEntitySubgraphRequest::deserialize(&request) - .map_err(Report::from) - .attach(hash_status::StatusCode::InvalidArgument) - .map_err(report_to_response)?; - let (query, options, traversal) = request.into_parts(); - - // TODO: https://linear.app/hash/issue/H-5351/reuse-parts-between-compilation-units - let mut heap = Heap::uninitialized(); - - if matches!(query, EntityQuery::Query { .. }) { - // The heap is going to be used in the compilation of the query and therefore needs to be - // primed. - // Doing this in a separate step allows us to be allocation free when not using HashQL - // queries. - heap.prime(); - } - - let filter = query.compile(&heap, CompilationOptions { interactive })?; - - let params = options - .into_traversal_params(filter, traversal, api_config) - .attach(hash_status::StatusCode::InvalidArgument) - .map_err(report_to_response)?; - - let response = store - .query_entity_subgraph(actor_id, params) - .await - .map(|response| { - Json(QueryEntitySubgraphResponse { - subgraph: response.subgraph.into(), - cursor: response.cursor.map(EntityQueryCursor::into_owned), - closed_multi_entity_types: response.closed_multi_entity_types, - definitions: response.definitions, - entity_permissions: response.entity_permissions, - }) - }) - .map_err(report_to_response); - if let Some(query_logger) = &mut query_logger { - query_logger.send().await.map_err(report_to_response)?; - } - response -} - -#[utoipa::path( - post, - path = "/entities/query/summarize", - request_body = SummarizeEntitiesParams, - tag = "Entity", - params( - ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), - - ), - responses( - ( - status = 200, - content_type = "application/json", - body = SummarizeEntitiesResponse, - ), - (status = 422, content_type = "text/plain", description = "Provided query is invalid"), - (status = 500, description = "Store error occurred"), - ) -)] -async fn summarize_entities( - AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, - store_pool: Extension>, - temporal_client: Extension>>, - mut query_logger: Option>, - Json(request): Json, -) -> Result, BoxedResponse> -where - S: StorePool + Send + Sync, -{ - if let Some(query_logger) = &mut query_logger { - query_logger.capture(actor_id, OpenApiQuery::SummarizeEntities(&request)); - } - - let store = store_pool - .acquire(temporal_client.0) - .await - .map_err(report_to_response)?; - - let response = store - .summarize_entities( - actor_id, - SummarizeEntitiesParams::deserialize(&request) - .map_err(Report::from) - .attach(hash_status::StatusCode::InvalidArgument) - .map_err(report_to_response)?, - ) - .await - .map(Json) - .map_err(report_to_response); - if let Some(query_logger) = &mut query_logger { - query_logger.send().await.map_err(report_to_response)?; - } - response -} - #[utoipa::path( patch, path = "/entities", diff --git a/libs/@local/graph/api/src/rest/entity/query/mod.rs b/libs/@local/graph/api/src/rest/entity/query/mod.rs new file mode 100644 index 00000000000..31e0b654121 --- /dev/null +++ b/libs/@local/graph/api/src/rest/entity/query/mod.rs @@ -0,0 +1,259 @@ +pub(crate) mod request; + +use alloc::sync::Arc; +use std::collections::HashMap; + +use axum::Extension; +use error_stack::{Report, ResultExt as _}; +use hash_graph_store::{ + entity::{ + ClosedMultiEntityTypeMap, CountEntitiesParams, EntityPermissions, EntityQueryCursor, + EntityStore as _, QueryEntitiesResponse, + }, + entity_type::EntityTypeResolveDefinitions, + pool::StorePool, +}; +use hash_temporal_client::TemporalClient; +use serde::Deserialize as _; +use serde_json::value::RawValue as RawJsonValue; +use type_system::{ + knowledge::entity::id::EntityId, + ontology::VersionedUrl, + principal::{actor::ActorEntityUuid, actor_group::WebId}, +}; + +pub use self::request::{ + QueryEntitiesRequest, QueryEntitySubgraphError, QueryEntitySubgraphRequest, +}; +use crate::rest::{ + ApiConfig, AuthenticatedUserHeader, OpenApiQuery, QueryLogger, + json::Json, + status::{BoxedResponse, report_to_response}, + utoipa_typedef::subgraph::Subgraph, +}; + +#[utoipa::path( + post, + path = "/entities/query", + request_body = QueryEntitiesRequest, + tag = "Entity", + params( + ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), + ("after" = Option, Query, description = "The cursor to start reading from"), + ("limit" = Option, Query, description = "The maximum number of entities to read"), + ), + responses( + ( + status = 200, + content_type = "application/json", + body = QueryEntitiesResponse, + description = "A list of entities that satisfy the given query.", + ), + (status = 422, content_type = "text/plain", description = "Provided query is invalid"), + (status = 500, description = "Store error occurred"), + ) +)] +pub(super) async fn query_entities( + AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, + store_pool: Extension>, + temporal_client: Extension>>, + Extension(api_config): Extension, + mut query_logger: Option>, + Json(request): Json>, +) -> Result>, BoxedResponse> +where + S: StorePool + Send + Sync, +{ + if let Some(query_logger) = &mut query_logger { + query_logger.capture(actor_id, OpenApiQuery::GetEntities(&request)); + } + + let store = store_pool + .acquire(temporal_client.0) + .await + .map_err(report_to_response)?; + + let request = QueryEntitiesRequest::deserialize(&*request) + .map_err(Report::from) + .map_err(report_to_response)?; + + let params = request + .into_params(api_config) + .attach(hash_status::StatusCode::InvalidArgument) + .map_err(report_to_response)?; + + let response = store + .query_entities(actor_id, params) + .await + .map(Json) + .map_err(report_to_response); + + if let Some(query_logger) = &mut query_logger { + query_logger.send().await.map_err(report_to_response)?; + } + response +} + +#[derive(serde::Serialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub(super) struct QueryEntitySubgraphResponse<'r> { + subgraph: Subgraph, + #[serde(borrow)] + cursor: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + closed_multi_entity_types: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + definitions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + web_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + created_by_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + edition_created_by_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + type_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + type_titles: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[schema(nullable = false)] + entity_permissions: Option>, +} + +#[utoipa::path( + post, + path = "/entities/query/subgraph", + request_body = QueryEntitySubgraphRequest, + tag = "Entity", + params( + ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), + ("after" = Option, Query, description = "The cursor to start reading from"), + ("limit" = Option, Query, description = "The maximum number of entities to read"), + ), + responses( + ( + status = 200, + content_type = "application/json", + body = QueryEntitySubgraphResponse, + description = "A subgraph rooted at entities that satisfy the given query, each resolved to the requested depth.", + ), + (status = 422, content_type = "text/plain", description = "Provided query is invalid"), + (status = 500, description = "Store error occurred"), + ) +)] +pub(super) async fn query_entity_subgraph( + AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, + store_pool: Extension>, + temporal_client: Extension>>, + Extension(api_config): Extension, + mut query_logger: Option>, + Json(request): Json, +) -> Result>, BoxedResponse> +where + S: StorePool + Send + Sync, +{ + if let Some(query_logger) = &mut query_logger { + query_logger.capture(actor_id, OpenApiQuery::GetEntitySubgraph(&request)); + } + + let store = store_pool + .acquire(temporal_client.0) + .await + .map_err(report_to_response)?; + + let request = QueryEntitySubgraphRequest::deserialize(&request) + .map_err(Report::from) + .map_err(report_to_response)?; + + let params = request + .into_traversal_params(api_config) + .attach(hash_status::StatusCode::InvalidArgument) + .map_err(report_to_response)?; + + let response = store + .query_entity_subgraph(actor_id, params) + .await + .map(|response| { + Json(QueryEntitySubgraphResponse { + subgraph: response.subgraph.into(), + cursor: response.cursor.map(EntityQueryCursor::into_owned), + count: response.count, + closed_multi_entity_types: response.closed_multi_entity_types, + definitions: response.definitions, + web_ids: response.web_ids, + created_by_ids: response.created_by_ids, + edition_created_by_ids: response.edition_created_by_ids, + type_ids: response.type_ids, + type_titles: response.type_titles, + entity_permissions: response.entity_permissions, + }) + }) + .map_err(report_to_response); + if let Some(query_logger) = &mut query_logger { + query_logger.send().await.map_err(report_to_response)?; + } + response +} + +#[utoipa::path( + post, + path = "/entities/query/count", + request_body = CountEntitiesParams, + tag = "Entity", + params( + ("X-Authenticated-User-Actor-Id" = ActorEntityUuid, Header, description = "The ID of the actor which is used to authorize the request"), + + ), + responses( + ( + status = 200, + content_type = "application/json", + body = usize, + ), + (status = 422, content_type = "text/plain", description = "Provided query is invalid"), + (status = 500, description = "Store error occurred"), + ) +)] +pub(super) async fn count_entities( + AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, + store_pool: Extension>, + temporal_client: Extension>>, + mut query_logger: Option>, + Json(request): Json, +) -> Result, BoxedResponse> +where + S: StorePool + Send + Sync, +{ + if let Some(query_logger) = &mut query_logger { + query_logger.capture(actor_id, OpenApiQuery::CountEntities(&request)); + } + + let store = store_pool + .acquire(temporal_client.0) + .await + .map_err(report_to_response)?; + + let response = store + .count_entities( + actor_id, + CountEntitiesParams::deserialize(&request) + .map_err(Report::from) + .map_err(report_to_response)?, + ) + .await + .map(Json) + .map_err(report_to_response); + if let Some(query_logger) = &mut query_logger { + query_logger.send().await.map_err(report_to_response)?; + } + response +} diff --git a/libs/@local/graph/api/src/rest/entity/query/request.rs b/libs/@local/graph/api/src/rest/entity/query/request.rs new file mode 100644 index 00000000000..dc256dd8f00 --- /dev/null +++ b/libs/@local/graph/api/src/rest/entity/query/request.rs @@ -0,0 +1,836 @@ +use error_stack::{Report, ResultExt as _}; +use hash_graph_store::{ + entity::{ + EntityQueryCursor, EntityQueryPath, EntityQuerySorting, EntityQuerySortingRecord, + QueryConversion, QueryEntitiesParams, QueryEntitySubgraphParams, + }, + entity_type::IncludeEntityTypeOption, + filter::Filter, + query::Ordering, + subgraph::{ + edges::{ + EntityTraversalPath, GraphResolveDepths, MAX_TRAVERSAL_PATHS, SubgraphTraversalParams, + TraversalDepthError, TraversalPath, + }, + temporal_axes::QueryTemporalAxesUnresolved, + }, +}; +use type_system::knowledge::Entity; + +use crate::rest::{ApiConfig, LimitExceededError, resolve_limit}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, derive_more::Display)] +pub enum QueryEntitySubgraphError { + #[display("Query limit exceeded")] + Limit, + #[display("Traversal depth exceeded")] + TraversalDepth, + #[display("Resolve depth exceeded")] + ResolveDepth, +} + +impl core::error::Error for QueryEntitySubgraphError {} + +fn validate_traversal( + params: &SubgraphTraversalParams, +) -> Result<(), Report> { + match params { + SubgraphTraversalParams::Paths { traversal_paths } => { + if traversal_paths.len() > MAX_TRAVERSAL_PATHS { + return Err(Report::new(TraversalDepthError::TooManyPaths { + actual: traversal_paths.len(), + max: MAX_TRAVERSAL_PATHS, + }) + .change_context(QueryEntitySubgraphError::TraversalDepth)); + } + for path in traversal_paths { + path.validate() + .change_context(QueryEntitySubgraphError::TraversalDepth)?; + } + } + SubgraphTraversalParams::ResolveDepths { + traversal_paths, + graph_resolve_depths, + } => { + if traversal_paths.len() > MAX_TRAVERSAL_PATHS { + return Err(Report::new(TraversalDepthError::TooManyPaths { + actual: traversal_paths.len(), + max: MAX_TRAVERSAL_PATHS, + }) + .change_context(QueryEntitySubgraphError::TraversalDepth)); + } + for path in traversal_paths { + path.validate() + .change_context(QueryEntitySubgraphError::TraversalDepth)?; + } + graph_resolve_depths + .validate() + .change_context(QueryEntitySubgraphError::ResolveDepth)?; + } + } + Ok(()) +} + +#[tracing::instrument(level = "info", skip_all)] +fn generate_sorting_paths( + paths: Option>>, + temporal_axes: &QueryTemporalAxesUnresolved, +) -> Vec> { + let temporal_axes_sorting_path = match temporal_axes { + QueryTemporalAxesUnresolved::TransactionTime { .. } => &EntityQueryPath::TransactionTime, + QueryTemporalAxesUnresolved::DecisionTime { .. } => &EntityQueryPath::DecisionTime, + }; + + paths + .map_or_else( + || { + vec![ + EntityQuerySortingRecord { + path: temporal_axes_sorting_path.clone(), + ordering: Ordering::Descending, + nulls: None, + }, + EntityQuerySortingRecord { + path: EntityQueryPath::Uuid, + ordering: Ordering::Ascending, + nulls: None, + }, + EntityQuerySortingRecord { + path: EntityQueryPath::WebId, + ordering: Ordering::Ascending, + nulls: None, + }, + ] + }, + |mut paths| { + let mut has_temporal_axis = false; + let mut has_uuid = false; + let mut has_web_id = false; + + for path in &paths { + if path.path == EntityQueryPath::TransactionTime + || path.path == EntityQueryPath::DecisionTime + { + has_temporal_axis = true; + } + if path.path == EntityQueryPath::Uuid { + has_uuid = true; + } + if path.path == EntityQueryPath::WebId { + has_web_id = true; + } + } + + if !has_temporal_axis { + paths.push(EntityQuerySortingRecord { + path: temporal_axes_sorting_path.clone(), + ordering: Ordering::Descending, + nulls: None, + }); + } + if !has_uuid { + paths.push(EntityQuerySortingRecord { + path: EntityQueryPath::Uuid, + ordering: Ordering::Ascending, + nulls: None, + }); + } + if !has_web_id { + paths.push(EntityQuerySortingRecord { + path: EntityQueryPath::WebId, + ordering: Ordering::Ascending, + nulls: None, + }); + } + + paths + }, + ) + .into_iter() + .map(EntityQuerySortingRecord::into_owned) + .collect() +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +#[expect( + clippy::struct_excessive_bools, + reason = "Parameter struct deserialized from JSON" +)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct QueryEntitiesRequest<'q, 's, 'p> { + #[serde(borrow)] + pub filter: Filter<'q, Entity>, + + pub temporal_axes: QueryTemporalAxesUnresolved, + pub include_drafts: bool, + pub limit: Option, + #[serde(borrow, default)] + pub conversions: Vec>, + #[serde(borrow)] + pub sorting_paths: Option>>, + #[serde(borrow)] + pub cursor: Option>, + #[serde(default)] + pub include_count: bool, + #[serde(default)] + pub include_entity_types: Option, + #[serde(default)] + pub include_web_ids: bool, + #[serde(default)] + pub include_created_by_ids: bool, + #[serde(default)] + pub include_edition_created_by_ids: bool, + #[serde(default)] + pub include_type_ids: bool, + #[serde(default)] + pub include_type_titles: bool, + pub include_permissions: bool, +} + +impl<'q, 'p> QueryEntitiesRequest<'q, '_, 'p> { + /// Convert this request into [`QueryEntitiesParams`] with the given [`ApiConfig`] and resolved + /// limit. + /// + /// Does not validate that the resolved limit does not exceed [`ApiConfig::query_entity_limit`]. + pub fn into_params_unchecked( + self, + config: ApiConfig, + limit: Option, + ) -> QueryEntitiesParams<'q> + where + 'p: 'q, + { + let limit = limit.or(self.limit).unwrap_or(config.query_entity_limit); + + QueryEntitiesParams { + filter: self.filter, + sorting: EntityQuerySorting { + paths: generate_sorting_paths(self.sorting_paths, &self.temporal_axes), + cursor: self.cursor.map(EntityQueryCursor::into_owned), + }, + limit, + conversions: self.conversions, + include_drafts: self.include_drafts, + include_count: self.include_count, + include_entity_types: self.include_entity_types, + temporal_axes: self.temporal_axes, + include_web_ids: self.include_web_ids, + include_created_by_ids: self.include_created_by_ids, + include_edition_created_by_ids: self.include_edition_created_by_ids, + include_type_ids: self.include_type_ids, + include_type_titles: self.include_type_titles, + include_permissions: self.include_permissions, + } + } + + /// Convert this request into [`QueryEntitiesParams`] with the given [`ApiConfig`] and resolved + /// limit. + /// + /// # Errors + /// + /// Returns [`LimitExceededError`] if the requested limit exceeds the configured maximum in + /// [`ApiConfig::query_entity_limit`]. + pub fn into_params( + self, + config: ApiConfig, + ) -> Result, Report> + where + 'p: 'q, + { + let limit = resolve_limit(self.limit, config.query_entity_limit)?; + + Ok(self.into_params_unchecked(config, Some(limit))) + } +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +#[serde(untagged, deny_unknown_fields)] +pub enum QueryEntitySubgraphRequest<'q, 's, 'p> { + #[serde(rename_all = "camelCase")] + ResolveDepths { + traversal_paths: Vec, + graph_resolve_depths: GraphResolveDepths, + #[serde(borrow, flatten)] + request: QueryEntitiesRequest<'q, 's, 'p>, + }, + #[serde(rename_all = "camelCase")] + Paths { + traversal_paths: Vec, + #[serde(borrow, flatten)] + request: QueryEntitiesRequest<'q, 's, 'p>, + }, +} + +impl<'q, 's, 'p> QueryEntitySubgraphRequest<'q, 's, 'p> { + #[must_use] + pub fn into_parts(self) -> (QueryEntitiesRequest<'q, 's, 'p>, SubgraphTraversalParams) { + match self { + QueryEntitySubgraphRequest::Paths { + traversal_paths, + request: options, + } => (options, SubgraphTraversalParams::Paths { traversal_paths }), + QueryEntitySubgraphRequest::ResolveDepths { + traversal_paths, + graph_resolve_depths, + request: options, + } => ( + options, + SubgraphTraversalParams::ResolveDepths { + traversal_paths, + graph_resolve_depths, + }, + ), + } + } + + #[must_use] + pub fn from_parts( + request: QueryEntitiesRequest<'q, 's, 'p>, + params: SubgraphTraversalParams, + ) -> Self { + match params { + SubgraphTraversalParams::Paths { traversal_paths } => { + QueryEntitySubgraphRequest::Paths { + traversal_paths, + request, + } + } + SubgraphTraversalParams::ResolveDepths { + traversal_paths, + graph_resolve_depths, + } => QueryEntitySubgraphRequest::ResolveDepths { + traversal_paths, + graph_resolve_depths, + request, + }, + } + } + + /// Convert the request into traversal parameters. Skipping validation. + #[must_use] + pub fn into_traversal_params_unchecked(self, config: ApiConfig) -> QueryEntitySubgraphParams<'q> + where + 'p: 'q, + { + let (request, params) = self.into_parts(); + let request = request.into_params_unchecked(config, None); + + match params { + SubgraphTraversalParams::Paths { traversal_paths } => { + QueryEntitySubgraphParams::Paths { + traversal_paths, + request, + } + } + SubgraphTraversalParams::ResolveDepths { + traversal_paths, + graph_resolve_depths, + } => QueryEntitySubgraphParams::ResolveDepths { + traversal_paths, + graph_resolve_depths, + request, + }, + } + } + + /// Convert the request into traversal parameters. + /// + /// # Errors + /// + /// Returns [`QueryEntitySubgraphError`] if: + /// - The requested limit exceeds the configured maximum. + /// - The number of traversal paths exceeds [`MAX_TRAVERSAL_PATHS`]. + /// - Any traversal path exceeds the maximum edge count. + /// - Graph resolve depths exceed the allowed maximum. + pub fn into_traversal_params( + self, + config: ApiConfig, + ) -> Result, Report> + where + 'p: 'q, + { + let (request, params) = self.into_parts(); + + validate_traversal(¶ms)?; + + let request = request + .into_params(config) + .change_context(QueryEntitySubgraphError::Limit)?; + + match params { + SubgraphTraversalParams::Paths { traversal_paths } => { + Ok(QueryEntitySubgraphParams::Paths { + traversal_paths, + request, + }) + } + SubgraphTraversalParams::ResolveDepths { + traversal_paths, + graph_resolve_depths, + } => Ok(QueryEntitySubgraphParams::ResolveDepths { + traversal_paths, + graph_resolve_depths, + request, + }), + } + } +} + +#[cfg(test)] +mod tests { + use core::assert_matches; + + use serde_json::json; + + use super::*; + + /// Minimal valid temporal axes for test payloads. + fn temporal_axes() -> serde_json::Value { + json!({ + "pinned": { + "axis": "transactionTime", + "timestamp": null + }, + "variable": { + "axis": "decisionTime", + "interval": { + "start": null, + "end": null + } + } + }) + } + + /// Minimal valid request body shared across tests. + fn base_request() -> String { + json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false + }) + .to_string() + } + + #[test] + fn deserialize_minimal_entity_request() { + let payload = base_request(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitiesRequest { + include_drafts: false, + include_permissions: false, + limit: None, + sorting_paths: None, + cursor: None, + include_count: false, + include_entity_types: None, + include_web_ids: false, + include_created_by_ids: false, + include_edition_created_by_ids: false, + include_type_ids: false, + include_type_titles: false, + .. + }) + ); + } + + #[test] + fn deserialize_entity_request_with_all_fields() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": true, + "includePermissions": true, + "limit": 50, + "includeCount": true, + "includeWebIds": true, + "includeCreatedByIds": true, + "includeEditionCreatedByIds": true, + "includeTypeIds": true, + "includeTypeTitles": true + }) + .to_string(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitiesRequest { + include_drafts: true, + include_permissions: true, + limit: Some(50), + include_count: true, + include_web_ids: true, + include_created_by_ids: true, + include_edition_created_by_ids: true, + include_type_ids: true, + include_type_titles: true, + .. + }) + ); + } + + #[test] + fn reject_entity_request_missing_filter() { + let payload = json!({ + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("missing filter should fail") + .to_string(); + assert!(err.starts_with("missing field `filter`"), "{err}"); + } + + #[test] + fn reject_entity_request_missing_temporal_axes() { + let payload = json!({ + "filter": { "all": [] }, + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("missing temporalAxes should fail") + .to_string(); + assert!(err.starts_with("missing field `temporalAxes`"), "{err}"); + } + + #[test] + fn reject_entity_request_missing_include_drafts() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includePermissions": false + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("missing includeDrafts should fail") + .to_string(); + assert!(err.starts_with("missing field `includeDrafts`"), "{err}"); + } + + #[test] + fn reject_entity_request_missing_include_permissions() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("missing includePermissions should fail") + .to_string(); + assert!( + err.starts_with("missing field `includePermissions`"), + "{err}" + ); + } + + #[test] + fn deserialize_subgraph_paths_variant() { + let payload = json!({ + "traversalPaths": [ + { + "edges": [ + { "kind": "has-left-entity", "direction": "incoming" }, + { "kind": "has-right-entity", "direction": "outgoing" } + ] + } + ], + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitySubgraphRequest::Paths { + traversal_paths, + request: QueryEntitiesRequest { include_drafts: false, .. }, + }) if traversal_paths.len() == 1 && traversal_paths[0].edges.len() == 2 + ); + } + + #[test] + fn deserialize_subgraph_resolve_depths_variant() { + let payload = json!({ + "traversalPaths": [], + "graphResolveDepths": { + "inheritsFrom": 0, + "constrainsValuesOn": 0, + "constrainsPropertiesOn": 0, + "constrainsLinksOn": 0, + "constrainsLinkDestinationsOn": 0, + "isOfType": false + }, + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitySubgraphRequest::ResolveDepths { + traversal_paths, + graph_resolve_depths: GraphResolveDepths { + inherits_from: 0, + is_of_type: false, + .. + }, + request: QueryEntitiesRequest { include_drafts: false, .. }, + }) if traversal_paths.is_empty() + ); + } + + #[test] + fn reject_subgraph_missing_traversal_paths() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("missing traversalPaths should fail") + .to_string(); + assert!( + err.starts_with( + "data did not match any variant of untagged enum QueryEntitySubgraphRequest" + ), + "{err}" + ); + } + + #[test] + fn deserialize_filter_request_with_limit_and_count() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "limit": 100, + "includeCount": true, + "includePermissions": false + }) + .to_string(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitiesRequest { + limit: Some(100), + include_count: true, + .. + }) + ); + } + + #[test] + fn deserialize_subgraph_resolve_depths_with_traversal() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "graphResolveDepths": { + "inheritsFrom": 255, + "constrainsValuesOn": 255, + "constrainsPropertiesOn": 255, + "constrainsLinksOn": 255, + "constrainsLinkDestinationsOn": 255, + "isOfType": true + }, + "traversalPaths": [ + { + "edges": [ + { "kind": "has-left-entity", "direction": "incoming" }, + { "kind": "has-right-entity", "direction": "outgoing" } + ] + } + ], + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitySubgraphRequest::ResolveDepths { + traversal_paths, + graph_resolve_depths: GraphResolveDepths { + inherits_from: 255, + is_of_type: true, + .. + }, + request: QueryEntitiesRequest { include_permissions: false, .. }, + }) if traversal_paths.len() == 1 + ); + } + + #[test] + fn reject_resolve_depths_with_non_entity_edge() { + // If traversalPaths contains an ontology edge (e.g. "is-of-type"), it can't + // deserialize as EntityTraversalPath. The untagged enum must not silently + // fall through to the Paths variant, dropping graphResolveDepths. + let payload = json!({ + "traversalPaths": [ + { + "edges": [ + { "kind": "is-of-type" } + ] + } + ], + "graphResolveDepths": { + "inheritsFrom": 255, + "constrainsValuesOn": 255, + "constrainsPropertiesOn": 255, + "constrainsLinksOn": 255, + "constrainsLinkDestinationsOn": 255, + "isOfType": true + }, + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + let result = serde_json::from_str::>(&payload); + + match result { + Err(_) => {} // Correctly rejected + Ok(QueryEntitySubgraphRequest::ResolveDepths { .. }) => { + panic!("should not parse ontology edges as EntityTraversalPath"); + } + Ok(QueryEntitySubgraphRequest::Paths { .. }) => { + panic!("silently fell through to Paths variant, dropping graphResolveDepths"); + } + } + } + + #[test] + fn deserialize_paths_with_ontology_edge() { + // Ontology edges (like is-of-type) are valid in TraversalPath but not + // EntityTraversalPath. Without graphResolveDepths, this should parse as Paths. + let payload = json!({ + "traversalPaths": [ + { + "edges": [ + { "kind": "has-left-entity", "direction": "incoming" }, + { "kind": "is-of-type" } + ] + } + ], + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitySubgraphRequest::Paths { + traversal_paths, + .. + }) if traversal_paths.len() == 1 && traversal_paths[0].edges.len() == 2 + ); + } + + #[test] + fn reject_entity_request_unknown_field() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false, + "bogusField": 42 + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("unknown field should be rejected") + .to_string(); + assert!(err.contains("bogusField"), "{err}"); + } + + #[test] + fn reject_subgraph_unknown_field_through_flatten() { + // The subgraph enum uses `#[serde(flatten)]` on the inner request. + // Verify that `deny_unknown_fields` still catches unknown keys that + // would pass through the flattened struct boundary. + let payload = json!({ + "traversalPaths": [ + { + "edges": [ + { "kind": "has-left-entity", "direction": "incoming" } + ] + } + ], + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false, + "bogusField": 42 + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("unknown field through flatten should be rejected") + .to_string(); + // With untagged + flatten, serde reports "did not match any variant" + // because both variants reject the unknown field. + assert!( + err.contains("bogusField") || err.contains("did not match any variant"), + "{err}" + ); + } + + #[test] + fn reject_subgraph_resolve_depths_unknown_field_through_flatten() { + let payload = json!({ + "traversalPaths": [], + "graphResolveDepths": { + "inheritsFrom": 0, + "constrainsValuesOn": 0, + "constrainsPropertiesOn": 0, + "constrainsLinksOn": 0, + "constrainsLinkDestinationsOn": 0, + "isOfType": false + }, + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "includeDrafts": false, + "includePermissions": false, + "sneakyExtra": true + }) + .to_string(); + let err = serde_json::from_str::>(&payload) + .expect_err("unknown field through flatten should be rejected") + .to_string(); + assert!( + err.contains("sneakyExtra") || err.contains("did not match any variant"), + "{err}" + ); + } + + #[test] + fn deserialize_subgraph_paths_with_traversal() { + let payload = json!({ + "filter": { "all": [] }, + "temporalAxes": temporal_axes(), + "traversalPaths": [ + { + "edges": [ + { "kind": "has-left-entity", "direction": "incoming" }, + { "kind": "has-right-entity", "direction": "outgoing" } + ] + } + ], + "includeDrafts": false, + "includePermissions": false + }) + .to_string(); + assert_matches!( + serde_json::from_str::>(&payload), + Ok(QueryEntitySubgraphRequest::Paths { + traversal_paths, + request: QueryEntitiesRequest { include_permissions: false, .. }, + }) if traversal_paths.len() == 1 && traversal_paths[0].edges.len() == 2 + ); + } +} diff --git a/libs/@local/graph/api/src/rest/entity_query_request.rs b/libs/@local/graph/api/src/rest/entity_query_request.rs deleted file mode 100644 index 1d5d0fa962b..00000000000 --- a/libs/@local/graph/api/src/rest/entity_query_request.rs +++ /dev/null @@ -1,975 +0,0 @@ -//! Request types for entity queries. -//! -//! Contains the deserialization structs for both simple entity queries and subgraph requests. -//! Some design choices may look odd due to serde/OpenAPI limitations we need to work around: -//! -//! - Uses proxy structs for deserialization because `RawValue` doesn't play nice with `untagged` + -//! `deny_unknown_fields` (forces intermediate representation). -//! - Subgraph enum has 4 variants instead of nested structs because openapi-generator uses `&` -//! instead of `|` for nested `oneOf` constraints. -//! - Outer enum instead of nested enum because utoipa generates `allOf` constraints (merges all -//! fields into one type). With discriminator on the outer edge we get `oneOf` (proper union), but -//! openapi-generator can't handle nested oneOf and merges them anyway - so we flatten everything -//! - Lots of boolean fields instead of option structs for the same reason -//! -//! When changing any of these types, make sure that the OpenAPI generator types do not degenerate -//! into any of these cases. -use alloc::borrow::Cow; -use core::{cmp, ops::Range}; - -use axum::{ - Json, - response::{Html, IntoResponse as _}, -}; -use error_stack::{Report, ResultExt as _}; -use hash_graph_store::{ - entity::{ - EntityQueryCursor, EntityQueryPath, EntityQuerySorting, EntityQuerySortingRecord, - QueryConversion, QueryEntitiesParams, QueryEntitySubgraphParams, SearchEntitiesFilter, - SearchEntitiesParams, - }, - entity_type::IncludeEntityTypeOption, - filter::{Filter, SemanticDistance}, - query::Ordering, - subgraph::{ - edges::{ - EntityTraversalPath, GraphResolveDepths, MAX_TRAVERSAL_PATHS, - ResolveDepthExceededError, SubgraphTraversalParams, SubgraphTraversalValidationError, - TraversalDepthError, TraversalPath, TraversalPathConversionError, - }, - temporal_axes::QueryTemporalAxesUnresolved, - }, -}; -use hash_graph_types::Embedding; -use hashql_ast::error::AstDiagnosticCategory; -use hashql_core::{ - collections::fast_hash_map_with_capacity, - heap::Heap, - module::ModuleRegistry, - span::{SpanId, SpanTable}, - r#type::environment::Environment, -}; -use hashql_diagnostics::{ - DiagnosticIssues, Failure, Severity, Status, StatusExt as _, Success, - category::{DiagnosticCategory, canonical_category_id}, - diagnostic::render::{Format, RenderOptions}, - source::{DiagnosticSpan, Source, SourceId, Sources}, -}; -use hashql_eval::{ - error::EvalDiagnosticCategory, - graph::{error::GraphCompilerDiagnosticCategory, read::FilterSlice}, -}; -use hashql_hir::{error::HirDiagnosticCategory, visit::Visitor as _}; -use hashql_syntax_jexpr::{error::JExprDiagnosticCategory, span::Span}; -use http::StatusCode; -use serde::Deserialize; -use serde_json::value::RawValue as RawJsonValue; -use type_system::knowledge::Entity; -use utoipa::ToSchema; - -use super::{ - ApiConfig, LimitExceededError, SearchRequestError, resolve_limit, status::BoxedResponse, -}; - -#[tracing::instrument(level = "info", skip_all)] -fn generate_sorting_paths( - paths: Option>>, - temporal_axes: &QueryTemporalAxesUnresolved, -) -> Vec> { - let temporal_axes_sorting_path = match temporal_axes { - QueryTemporalAxesUnresolved::TransactionTime { .. } => &EntityQueryPath::TransactionTime, - QueryTemporalAxesUnresolved::DecisionTime { .. } => &EntityQueryPath::DecisionTime, - }; - - paths - .map_or_else( - || { - vec![ - EntityQuerySortingRecord { - path: temporal_axes_sorting_path.clone(), - ordering: Ordering::Descending, - nulls: None, - }, - EntityQuerySortingRecord { - path: EntityQueryPath::Uuid, - ordering: Ordering::Ascending, - nulls: None, - }, - EntityQuerySortingRecord { - path: EntityQueryPath::WebId, - ordering: Ordering::Ascending, - nulls: None, - }, - ] - }, - |mut paths| { - let mut has_temporal_axis = false; - let mut has_uuid = false; - let mut has_web_id = false; - - for path in &paths { - if path.path == EntityQueryPath::TransactionTime - || path.path == EntityQueryPath::DecisionTime - { - has_temporal_axis = true; - } - if path.path == EntityQueryPath::Uuid { - has_uuid = true; - } - if path.path == EntityQueryPath::WebId { - has_web_id = true; - } - } - - if !has_temporal_axis { - paths.push(EntityQuerySortingRecord { - path: temporal_axes_sorting_path.clone(), - ordering: Ordering::Descending, - nulls: None, - }); - } - if !has_uuid { - paths.push(EntityQuerySortingRecord { - path: EntityQueryPath::Uuid, - ordering: Ordering::Ascending, - nulls: None, - }); - } - if !has_web_id { - paths.push(EntityQuerySortingRecord { - path: EntityQueryPath::WebId, - ordering: Ordering::Ascending, - nulls: None, - }); - } - - paths - }, - ) - .into_iter() - .map(EntityQuerySortingRecord::into_owned) - .collect() -} - -/// Internal deserialization proxy for `QueryEntitiesRequest`. -/// -/// This struct is necessary because [`RawJsonValue`] cannot be used directly with -/// `#[serde(untagged, deny_unknown_fields)]` - these attributes force deserialization into an -/// intermediate representation, which cannot deserialize into a [`RawJsonValue`] as it materializes -/// the content. -/// -/// See and for more details. -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -struct FlatQueryEntitiesRequestData<'q, 's, 'p> { - // `QueryEntitiesQuery::Filter` - #[serde(borrow)] - filter: Option>, - // `QueryEntitiesQuery::Query`, - #[serde(borrow)] - query: Option<&'q RawJsonValue>, - - // `QueryEntitiesRequest` - temporal_axes: QueryTemporalAxesUnresolved, - include_drafts: bool, - limit: Option, - #[serde(borrow, default)] - conversions: Vec>, - #[serde(borrow)] - sorting_paths: Option>>, - #[serde(borrow)] - cursor: Option>, - #[serde(default)] - include_entity_types: Option, - include_permissions: bool, - - traversal_paths: Option>, - graph_resolve_depths: Option, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub(crate) struct CompilationOptions { - pub interactive: bool, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -enum HashQLDiagnosticCategory { - JExpr(JExprDiagnosticCategory), - Ast(AstDiagnosticCategory), - Hir(HirDiagnosticCategory), - Eval(EvalDiagnosticCategory), -} - -impl serde::Serialize for HashQLDiagnosticCategory { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.collect_str(&canonical_category_id(self)) - } -} - -impl DiagnosticCategory for HashQLDiagnosticCategory { - fn id(&self) -> Cow<'_, str> { - Cow::Borrowed("hashql") - } - - fn name(&self) -> Cow<'_, str> { - Cow::Borrowed("HashQL") - } - - fn subcategory(&self) -> Option<&dyn DiagnosticCategory> { - match self { - Self::JExpr(jexpr) => Some(jexpr), - Self::Ast(ast) => Some(ast), - Self::Hir(hir) => Some(hir), - Self::Eval(eval) => Some(eval), - } - } -} - -#[derive(Debug, serde::Serialize)] -struct ResolvedSpan { - pub range: Range, - pub pointer: Option, -} - -fn resolve_span(id: SpanId, mut spans: &SpanTable) -> Option { - let absolute = DiagnosticSpan::absolute(&id, &mut spans)?; - let mut pointer = spans.get(id)?.pointer.as_ref().map(ToString::to_string); - - for ancestor in spans.ancestors(id) { - let Some(ancestor) = spans.get(ancestor) else { - continue; - }; - - if pointer.is_none() - && let Some(ancestor_pointer) = &ancestor.pointer - { - pointer = Some(ancestor_pointer.to_string()); - } - } - - Some(ResolvedSpan { - range: absolute.range().into(), - pointer, - }) -} - -fn issues_to_response( - issues: DiagnosticIssues, - severity: Severity, - source: &str, - mut spans: &SpanTable, - options: CompilationOptions, -) -> BoxedResponse { - let status_code = match severity { - Severity::Bug | Severity::Fatal => StatusCode::INTERNAL_SERVER_ERROR, - Severity::Error => StatusCode::BAD_REQUEST, - Severity::Warning | Severity::Note | Severity::Debug => StatusCode::CONFLICT, - }; - - let mut sources = Sources::new(); - sources.push(Source::new(source)); - - let mut response = if options.interactive { - let output = issues.render(RenderOptions::new(Format::Html, &sources), &mut spans); - - Html(output).into_response() - } else { - let diagnostics: Vec<_> = issues - .into_iter() - .map(|diagnostic| diagnostic.map_spans(|span| resolve_span(span, spans))) - .collect(); - - Json(diagnostics).into_response() - }; - - *response.status_mut() = status_code; - response.into() -} - -fn failure_to_response( - failure: Failure, - source: &str, - spans: &SpanTable, - options: CompilationOptions, -) -> BoxedResponse { - // Find the highest diagnostic level - let severity = cmp::max( - failure - .secondary - .iter() - .map(|diagnostic| diagnostic.severity) - .max() - .unwrap_or(Severity::Debug), - failure.primary.severity.into(), - ); - - issues_to_response(failure.into_issues(), severity, source, spans, options) -} - -#[derive(Debug, Clone)] -#[expect(clippy::large_enum_variant)] -pub enum EntityQuery<'q> { - Filter { filter: Filter<'q, Entity> }, - Query { query: &'q RawJsonValue }, -} - -impl<'q> EntityQuery<'q> { - fn compile_query<'heap>( - heap: &'heap Heap, - spans: &mut SpanTable, - query: &RawJsonValue, - ) -> Status, HashQLDiagnosticCategory, SpanId> { - // Parse the query - let mut parser = hashql_syntax_jexpr::Parser::new(heap, spans); - let mut ast = parser - .parse_expr(query.get().as_bytes()) - .map_err(|diagnostic| { - Failure::new(diagnostic.map_category(HashQLDiagnosticCategory::JExpr)) - })?; - - let mut env = Environment::new(heap); - let modules = ModuleRegistry::new(&env); - - // Lower the AST - let Success { - value: types, - advisories, - } = hashql_ast::lowering::lower(heap.intern_symbol("main"), &mut ast, &env, &modules) - .map_category(|category| { - HashQLDiagnosticCategory::Ast(AstDiagnosticCategory::Lowering(category)) - })?; - - let interner = hashql_hir::intern::Interner::new(heap); - let mut context = hashql_hir::context::HirContext::new(&interner, &modules); - - // Reify the HIR from the AST - let Success { - value: hir, - advisories, - } = hashql_hir::node::NodeData::from_ast(ast, &mut context, &types) - .map_category(|category| { - HashQLDiagnosticCategory::Hir(HirDiagnosticCategory::Reification(category)) - }) - .with_diagnostics(advisories)?; - - // Lower the HIR - let Success { - value: hir, - advisories, - } = hashql_hir::lower::lower(hir, &types, &mut env, &mut context) - .map_category(|category| { - HashQLDiagnosticCategory::Hir(HirDiagnosticCategory::Lowering(category)) - }) - .with_diagnostics(advisories)?; - - // Evaluate the HIR - // TODO: https://linear.app/hash/issue/BE-41/hashql-expose-input-in-graph-api - let inputs = fast_hash_map_with_capacity(0); - let mut compiler = hashql_eval::graph::read::GraphReadCompiler::new(heap, &inputs); - - compiler.visit_node(hir); - - let Success { - value: result, - advisories, - } = compiler - .finish() - .map_category(|category| { - HashQLDiagnosticCategory::Eval(EvalDiagnosticCategory::Graph( - GraphCompilerDiagnosticCategory::Read(category), - )) - }) - .with_diagnostics(advisories)?; - - let output = result.output.get(&hir.id).expect("TODO"); - - // Compile the Filter into one - let filters = match output { - FilterSlice::Entity { range } => result.filters.entity(range.clone()), - }; - - let filter = match filters { - [] => Filter::All(Vec::new()), - [filter] => filter.clone(), - _ => Filter::All(filters.to_vec()), - }; - - Ok(Success { - value: filter, - advisories, - }) - } - - /// Compiles a query into an executable entity filter. - /// - /// Transforms the query representation into a [`Filter`] that can be executed - /// against the entity store. For already-compiled filter queries, this returns - /// the filter directly. For raw HashQL queries, it parses and compiles them using - /// the provided `heap` arena allocator. - /// - /// # Errors - /// - /// Returns an error if the HashQL query cannot be compiled. - pub(crate) fn compile( - self, - heap: &'q Heap, - options: CompilationOptions, - ) -> Result, BoxedResponse> { - match self { - EntityQuery::Filter { filter } => Ok(filter), - EntityQuery::Query { query } => { - let mut spans = SpanTable::new(SourceId::new_unchecked(0x00)); - - let Success { - value: filter, - advisories, - } = Self::compile_query(heap, &mut spans, query).map_err(|failure| { - failure_to_response(failure, query.get(), &spans, options) - })?; - if !advisories.is_empty() { - // This isn't perfect, what we'd want instead is to return it alongside the - // response, the problem with that approach is just how: we'd need to adjust the - // return type, and respect interactive. Returning warnings before so that user - // can fix them before trying again seems to be the best approach for now. - return Err(issues_to_response( - advisories.generalize(), - Severity::Warning, - query.get(), - &spans, - options, - )); - } - - Ok(filter) - } - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, derive_more::Display)] -pub enum EntityQueryOptionsError { - #[display( - "Field '{field}' is only valid in subgraph requests. Use the subgraph endpoint instead." - )] - InvalidFieldForEntityQuery { field: &'static str }, - #[display( - "Field '{field}' is only valid in entity and subgraph requests. Use the entity endpoint \ - instead." - )] - InvalidFieldForEntityOptions { field: &'static str }, -} - -impl core::error::Error for EntityQueryOptionsError {} - -#[derive(Debug, Clone, Deserialize, ToSchema)] -#[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct EntityQueryOptions<'s, 'p> { - pub temporal_axes: QueryTemporalAxesUnresolved, - pub include_drafts: bool, - pub limit: Option, - #[serde(borrow, default)] - pub conversions: Vec>, - #[serde(borrow)] - pub sorting_paths: Option>>, - #[serde(borrow)] - pub cursor: Option>, - #[serde(default)] - pub include_entity_types: Option, - pub include_permissions: bool, -} - -impl<'q, 's, 'p> TryFrom> for EntityQueryOptions<'s, 'p> { - type Error = EntityQueryOptionsError; - - fn try_from(value: FlatQueryEntitiesRequestData<'q, 's, 'p>) -> Result { - let FlatQueryEntitiesRequestData { - filter, - query, - temporal_axes, - include_drafts, - limit, - conversions, - sorting_paths, - cursor, - include_entity_types, - include_permissions, - graph_resolve_depths, - traversal_paths, - } = value; - - if filter.is_some() { - return Err(EntityQueryOptionsError::InvalidFieldForEntityOptions { field: "filter" }); - } - - if query.is_some() { - return Err(EntityQueryOptionsError::InvalidFieldForEntityOptions { field: "query" }); - } - - if graph_resolve_depths.is_some() { - return Err(EntityQueryOptionsError::InvalidFieldForEntityQuery { - field: "graphResolveDepths", - }); - } - - if traversal_paths.is_some() { - return Err(EntityQueryOptionsError::InvalidFieldForEntityQuery { - field: "traversalPaths", - }); - } - - Ok(Self { - temporal_axes, - include_drafts, - limit, - conversions, - sorting_paths, - cursor, - include_entity_types, - include_permissions, - }) - } -} - -impl<'p> EntityQueryOptions<'_, 'p> { - /// # Errors - /// - /// Returns [`LimitExceededError`] if the requested limit exceeds the configured maximum in - /// [`ApiConfig::query_entity_limit`]. - pub fn into_params<'f>( - self, - filter: Filter<'f, Entity>, - config: ApiConfig, - ) -> Result, Report> - where - 'p: 'f, - { - let limit = resolve_limit(self.limit, config.query_entity_limit)?; - - Ok(QueryEntitiesParams { - filter, - sorting: EntityQuerySorting { - paths: generate_sorting_paths(self.sorting_paths, &self.temporal_axes), - cursor: self.cursor.map(EntityQueryCursor::into_owned), - }, - limit, - conversions: self.conversions, - include_drafts: self.include_drafts, - include_entity_types: self.include_entity_types, - temporal_axes: self.temporal_axes, - include_permissions: self.include_permissions, - }) - } - - /// # Errors - /// - /// Returns [`LimitExceededError`] if the requested limit exceeds the configured maximum in - /// [`ApiConfig::query_entity_limit`]. - pub fn into_traversal_params<'q>( - self, - filter: Filter<'q, Entity>, - traversal: SubgraphTraversalParams, - config: ApiConfig, - ) -> Result, Report> - where - 'p: 'q, - { - match traversal { - SubgraphTraversalParams::Paths { traversal_paths } => { - Ok(QueryEntitySubgraphParams::Paths { - traversal_paths, - request: self.into_params(filter, config)?, - }) - } - SubgraphTraversalParams::ResolveDepths { - traversal_paths, - graph_resolve_depths, - } => Ok(QueryEntitySubgraphParams::ResolveDepths { - traversal_paths, - graph_resolve_depths, - request: self.into_params(filter, config)?, - }), - } - } -} - -/// Request body for the entity embedding search endpoint. -#[derive(Debug, Deserialize, ToSchema)] -#[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct SearchEntitiesRequest { - pub embedding: Embedding<'static>, - pub maximum_semantic_distance: f64, - pub limit: Option, - #[serde(default)] - pub include_entity_types: bool, - #[serde(default)] - pub filter: SearchEntitiesFilter, -} - -impl SearchEntitiesRequest { - /// # Errors - /// - /// - [`InvalidSemanticDistance`] if the maximum semantic distance is invalid. - /// - [`LimitExceeded`] if the requested limit exceeds the configured maximum. - /// - /// [`InvalidSemanticDistance`]: [`SearchRequestError::InvalidSemanticDistance`] - /// [`LimitExceeded`]: [`SearchRequestError::LimitExceeded`] - pub fn into_params( - self, - config: ApiConfig, - ) -> Result> { - Ok(SearchEntitiesParams { - embedding: self.embedding, - maximum_semantic_distance: SemanticDistance::try_from(self.maximum_semantic_distance) - .change_context( - SearchRequestError::InvalidSemanticDistance, - )?, - limit: resolve_limit(self.limit, config.query_entity_limit) - .change_context(SearchRequestError::LimitExceeded)?, - include_entity_types: self.include_entity_types, - filter: self.filter, - }) - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, derive_more::Display, derive_more::From)] -enum QueryEntitiesRequestError { - #[from] - RequestOptions(EntityQueryOptionsError), - #[display("Missing required query parameter. Provide either 'filter' or 'query'.")] - MissingQueryParameter, - #[display("Conflicting query parameters. Provide either 'filter' or 'query', not both.")] - ConflictingQueryParameters, -} - -impl core::error::Error for QueryEntitiesRequestError {} - -#[derive(Debug, Clone, Deserialize, ToSchema)] -#[serde( - untagged, - try_from = "FlatQueryEntitiesRequestData", - deny_unknown_fields -)] -#[expect(clippy::large_enum_variant)] -pub enum QueryEntitiesRequest<'q, 's, 'p> { - #[serde(rename_all = "camelCase")] - Query { - #[serde(borrow)] - #[schema(value_type = utoipa::openapi::schema::Value)] - query: &'q RawJsonValue, - #[serde(borrow, flatten)] - options: EntityQueryOptions<'s, 'p>, - }, - #[serde(rename_all = "camelCase")] - Filter { - #[serde(borrow)] - filter: Filter<'q, Entity>, - #[serde(borrow, flatten)] - options: EntityQueryOptions<'s, 'p>, - }, -} - -impl<'q, 's, 'p> TryFrom> - for QueryEntitiesRequest<'q, 's, 'p> -{ - type Error = QueryEntitiesRequestError; - - fn try_from(mut value: FlatQueryEntitiesRequestData<'q, 's, 'p>) -> Result { - let filter = value.filter.take(); - let query = value.query.take(); - - match (filter, query) { - (None, None) => Err(QueryEntitiesRequestError::MissingQueryParameter), - (Some(_), Some(_)) => Err(QueryEntitiesRequestError::ConflictingQueryParameters), - (Some(filter), None) => Ok(Self::Filter { - filter, - options: value.try_into()?, - }), - (None, Some(query)) => Ok(Self::Query { - query, - options: value.try_into()?, - }), - } - } -} - -impl<'q, 's, 'p> QueryEntitiesRequest<'q, 's, 'p> { - #[must_use] - pub fn from_parts(query: EntityQuery<'q>, options: EntityQueryOptions<'s, 'p>) -> Self { - match query { - EntityQuery::Filter { filter } => Self::Filter { filter, options }, - EntityQuery::Query { query } => Self::Query { query, options }, - } - } - - #[must_use] - pub fn into_parts(self) -> (EntityQuery<'q>, EntityQueryOptions<'s, 'p>) { - match self { - QueryEntitiesRequest::Query { query, options } => { - (EntityQuery::Query { query }, options) - } - QueryEntitiesRequest::Filter { filter, options } => { - (EntityQuery::Filter { filter }, options) - } - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, derive_more::Display, derive_more::From)] -enum QueryEntitySubgraphRequestError { - #[from] - QueryEntityRequest(QueryEntitiesRequestError), - #[from] - UnsupportedGraphTraversalPath(TraversalPathConversionError), - #[display( - "Subgraph request missing traversal parameters. Specify either 'traversalPaths` and \ - optionally `graphResolveDepths'." - )] - MissingSubgraphTraversal, - #[from] - TraversalValidation(SubgraphTraversalValidationError), -} - -impl core::error::Error for QueryEntitySubgraphRequestError {} - -impl From for QueryEntitySubgraphRequestError { - fn from(err: TraversalDepthError) -> Self { - Self::TraversalValidation(err.into()) - } -} - -impl From for QueryEntitySubgraphRequestError { - fn from(err: ResolveDepthExceededError) -> Self { - Self::TraversalValidation(err.into()) - } -} - -#[derive(Debug, Clone, Deserialize, ToSchema)] -#[serde( - untagged, - try_from = "FlatQueryEntitiesRequestData", - deny_unknown_fields -)] -pub enum QueryEntitySubgraphRequest<'q, 's, 'p> { - #[serde(rename_all = "camelCase")] - ResolveDepthsWithQuery { - #[serde(borrow)] - #[schema(value_type = utoipa::openapi::schema::Value)] - query: &'q RawJsonValue, - traversal_paths: Vec, - graph_resolve_depths: GraphResolveDepths, - #[serde(borrow, flatten)] - options: EntityQueryOptions<'s, 'p>, - }, - #[serde(rename_all = "camelCase")] - ResolveDepthsWithFilter { - #[serde(borrow)] - filter: Filter<'q, Entity>, - traversal_paths: Vec, - graph_resolve_depths: GraphResolveDepths, - #[serde(borrow, flatten)] - options: EntityQueryOptions<'s, 'p>, - }, - #[serde(rename_all = "camelCase")] - PathsWithQuery { - #[serde(borrow)] - #[schema(value_type = utoipa::openapi::schema::Value)] - query: &'q RawJsonValue, - traversal_paths: Vec, - #[serde(borrow, flatten)] - options: EntityQueryOptions<'s, 'p>, - }, - #[serde(rename_all = "camelCase")] - PathsWithFilter { - #[serde(borrow)] - filter: Filter<'q, Entity>, - traversal_paths: Vec, - #[serde(borrow, flatten)] - options: EntityQueryOptions<'s, 'p>, - }, -} - -impl<'q, 's, 'p> TryFrom> - for QueryEntitySubgraphRequest<'q, 's, 'p> -{ - type Error = QueryEntitySubgraphRequestError; - - fn try_from(mut value: FlatQueryEntitiesRequestData<'q, 's, 'p>) -> Result { - let graph_resolve_depths = value.graph_resolve_depths.take(); - let traversal_paths = value - .traversal_paths - .take() - .ok_or(QueryEntitySubgraphRequestError::MissingSubgraphTraversal)?; - - if traversal_paths.len() > MAX_TRAVERSAL_PATHS { - return Err(TraversalDepthError::TooManyPaths { - actual: traversal_paths.len(), - max: MAX_TRAVERSAL_PATHS, - } - .into()); - } - - let request = value.try_into()?; - - match graph_resolve_depths { - None => { - for path in &traversal_paths { - path.validate()?; - } - match request { - QueryEntitiesRequest::Filter { filter, options } => { - Ok(QueryEntitySubgraphRequest::PathsWithFilter { - traversal_paths, - filter, - options, - }) - } - QueryEntitiesRequest::Query { query, options } => { - Ok(QueryEntitySubgraphRequest::PathsWithQuery { - traversal_paths, - query, - options, - }) - } - } - } - Some(graph_resolve_depths) => { - let entity_paths: Vec = traversal_paths - .into_iter() - .map(EntityTraversalPath::try_from) - .collect::>()?; - for path in &entity_paths { - path.validate()?; - } - graph_resolve_depths.validate()?; - match request { - QueryEntitiesRequest::Filter { filter, options } => { - Ok(QueryEntitySubgraphRequest::ResolveDepthsWithFilter { - traversal_paths: entity_paths, - graph_resolve_depths, - filter, - options, - }) - } - QueryEntitiesRequest::Query { query, options } => { - Ok(QueryEntitySubgraphRequest::ResolveDepthsWithQuery { - traversal_paths: entity_paths, - graph_resolve_depths, - query, - options, - }) - } - } - } - } - } -} - -impl<'q, 's, 'p> QueryEntitySubgraphRequest<'q, 's, 'p> { - #[must_use] - pub fn from_parts( - query: EntityQuery<'q>, - options: EntityQueryOptions<'s, 'p>, - traversal_params: SubgraphTraversalParams, - ) -> Self { - match (query, traversal_params) { - ( - EntityQuery::Filter { filter }, - SubgraphTraversalParams::Paths { traversal_paths }, - ) => Self::PathsWithFilter { - filter, - options, - traversal_paths, - }, - (EntityQuery::Query { query }, SubgraphTraversalParams::Paths { traversal_paths }) => { - Self::PathsWithQuery { - query, - traversal_paths, - options, - } - } - ( - EntityQuery::Filter { filter }, - SubgraphTraversalParams::ResolveDepths { - traversal_paths, - graph_resolve_depths, - }, - ) => Self::ResolveDepthsWithFilter { - filter, - options, - traversal_paths, - graph_resolve_depths, - }, - ( - EntityQuery::Query { query }, - SubgraphTraversalParams::ResolveDepths { - traversal_paths, - graph_resolve_depths, - }, - ) => Self::ResolveDepthsWithQuery { - query, - options, - traversal_paths, - graph_resolve_depths, - }, - } - } - - #[must_use] - pub fn into_parts( - self, - ) -> ( - EntityQuery<'q>, - EntityQueryOptions<'s, 'p>, - SubgraphTraversalParams, - ) { - match self { - QueryEntitySubgraphRequest::PathsWithQuery { - query, - traversal_paths, - options, - } => ( - EntityQuery::Query { query }, - options, - SubgraphTraversalParams::Paths { traversal_paths }, - ), - QueryEntitySubgraphRequest::PathsWithFilter { - filter, - traversal_paths, - options, - } => ( - EntityQuery::Filter { filter }, - options, - SubgraphTraversalParams::Paths { traversal_paths }, - ), - QueryEntitySubgraphRequest::ResolveDepthsWithQuery { - query, - traversal_paths, - graph_resolve_depths, - options, - } => ( - EntityQuery::Query { query }, - options, - SubgraphTraversalParams::ResolveDepths { - traversal_paths, - graph_resolve_depths, - }, - ), - QueryEntitySubgraphRequest::ResolveDepthsWithFilter { - filter, - traversal_paths, - graph_resolve_depths, - options, - } => ( - EntityQuery::Filter { filter }, - options, - SubgraphTraversalParams::ResolveDepths { - traversal_paths, - graph_resolve_depths, - }, - ), - } - } -} diff --git a/libs/@local/graph/api/src/rest/mod.rs b/libs/@local/graph/api/src/rest/mod.rs index 8e3167fde0e..c67668ab898 100644 --- a/libs/@local/graph/api/src/rest/mod.rs +++ b/libs/@local/graph/api/src/rest/mod.rs @@ -16,7 +16,6 @@ pub mod admin; pub mod http_tracing_layer; pub mod jwt; -mod entity_query_request; mod json; mod utoipa_typedef; use alloc::{borrow::Cow, sync::Arc}; diff --git a/libs/@local/hashql/compiletest/src/suite/eval_postgres.rs b/libs/@local/hashql/compiletest/src/suite/eval_postgres.rs index 7cd0bd37246..f38539e2592 100644 --- a/libs/@local/hashql/compiletest/src/suite/eval_postgres.rs +++ b/libs/@local/hashql/compiletest/src/suite/eval_postgres.rs @@ -7,7 +7,7 @@ use hashql_core::{ r#type::{TypeFormatter, TypeFormatterOptions, environment::Environment}, }; use hashql_diagnostics::DiagnosticIssues; -use hashql_eval::{context::EvalContext, postgres::PostgresCompiler}; +use hashql_eval::{context::CodeGenerationContext, postgres::PostgresCompiler}; use hashql_mir::{ body::{Body, basic_block::BasicBlockId, terminator::TerminatorKind}, context::MirContext, @@ -117,7 +117,8 @@ impl Suite for EvalPostgres { let mir_buf = format_mir_with_placement(heap, &environment, &bodies, &analysis); secondary_outputs.insert("mir", mir_buf); - let mut context = EvalContext::new_in( + let interner = interner.into(); + let mut context = CodeGenerationContext::new_in( &environment, &interner, &bodies, diff --git a/libs/@local/hashql/core/src/graph/linked.rs b/libs/@local/hashql/core/src/graph/linked.rs index 6494bd62ad1..1b5d19d6784 100644 --- a/libs/@local/hashql/core/src/graph/linked.rs +++ b/libs/@local/hashql/core/src/graph/linked.rs @@ -138,11 +138,13 @@ impl HasId for Edge { impl Edge { /// Returns the source node of this edge. + #[inline] pub const fn source(&self) -> NodeId { self.source } /// Returns the target node of this edge. + #[inline] pub const fn target(&self) -> NodeId { self.target } @@ -542,11 +544,9 @@ impl Default for LinkedGraph { } } -impl core::fmt::Debug - for LinkedGraph -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("LinkedGraph") +impl fmt::Debug for LinkedGraph { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("LinkedGraph") .field("nodes", &self.nodes) .field("edges", &self.edges) .finish() diff --git a/libs/@local/hashql/eval/src/context.rs b/libs/@local/hashql/eval/src/context.rs index e3c1958784b..5d59812c71c 100644 --- a/libs/@local/hashql/eval/src/context.rs +++ b/libs/@local/hashql/eval/src/context.rs @@ -1,4 +1,4 @@ -use core::{alloc::Allocator, ops::Index}; +use core::{alloc::Allocator, fmt, ops::Index}; use hashql_core::{ heap::BumpAllocator, id::bit_vec::DenseBitSet, r#type::environment::Environment, @@ -11,7 +11,6 @@ use hashql_mir::{ local::Local, }, def::{DefId, DefIdSlice, DefIdVec}, - intern::Interner, pass::{ analysis::dataflow::{ TraversalLivenessAnalysis, @@ -21,7 +20,7 @@ use hashql_mir::{ }, }; -use crate::error::EvalDiagnosticIssues; +use crate::{error::EvalDiagnosticIssues, intern::Interner}; struct BasicBlockLiveOut( Box, TraversalPathBitSet)>, A>, @@ -49,7 +48,7 @@ impl Index<(DefId, BasicBlockId)> for LiveOut { } } -pub struct EvalContext<'ctx, 'heap, A: Allocator> { +pub struct CodeGenerationContext<'ctx, 'heap, A: Allocator> { pub env: &'ctx Environment<'heap>, pub interner: &'ctx Interner<'heap>, @@ -58,10 +57,11 @@ pub struct EvalContext<'ctx, 'heap, A: Allocator> { pub live_out: LiveOut, pub diagnostics: EvalDiagnosticIssues, + pub alloc: A, } -impl<'ctx, 'heap, A: Allocator> EvalContext<'ctx, 'heap, A> { +impl<'ctx, 'heap, A: Allocator> CodeGenerationContext<'ctx, 'heap, A> { pub fn new_in( env: &'ctx Environment<'heap>, interner: &'ctx Interner<'heap>, @@ -121,3 +121,39 @@ impl<'ctx, 'heap, A: Allocator> EvalContext<'ctx, 'heap, A> { } } } + +#[derive(Copy, Clone)] +pub struct CodeExecutionContext<'ctx, 'heap, A: Allocator> { + pub env: &'ctx Environment<'heap>, + pub interner: &'ctx Interner<'heap>, + + pub bodies: &'ctx DefIdSlice>, + pub execution: &'ctx DefIdSlice>>, + + pub alloc: A, +} + +impl fmt::Debug for CodeExecutionContext<'_, '_, A> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("CodeExecutionContext") + .field("env", &self.env) + .field("interner", &self.interner) + .field("bodies", &self.bodies) + .field("execution", &self.execution) + .finish_non_exhaustive() + } +} + +impl<'ctx, 'heap, A: Allocator> From> + for CodeExecutionContext<'ctx, 'heap, A> +{ + fn from(context: CodeGenerationContext<'ctx, 'heap, A>) -> Self { + Self { + env: context.env, + interner: context.interner, + bodies: context.bodies, + execution: context.execution, + alloc: context.alloc, + } + } +} diff --git a/libs/@local/hashql/eval/src/error.rs b/libs/@local/hashql/eval/src/error.rs index bdfc6b5e556..91388d645fc 100644 --- a/libs/@local/hashql/eval/src/error.rs +++ b/libs/@local/hashql/eval/src/error.rs @@ -3,18 +3,17 @@ use alloc::borrow::Cow; use hashql_core::span::SpanId; use hashql_diagnostics::{Diagnostic, DiagnosticIssues, Severity, category::DiagnosticCategory}; -#[cfg(feature = "graph")] -use crate::graph::error::GraphCompilerDiagnosticCategory; -use crate::postgres::error::PostgresDiagnosticCategory; +use crate::{ + orchestrator::OrchestratorDiagnosticCategory, postgres::error::PostgresDiagnosticCategory, +}; pub type EvalDiagnostic = Diagnostic; pub type EvalDiagnosticIssues = DiagnosticIssues; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum EvalDiagnosticCategory { - #[cfg(feature = "graph")] - Graph(GraphCompilerDiagnosticCategory), Postgres(PostgresDiagnosticCategory), + Orchestrator(OrchestratorDiagnosticCategory), } impl DiagnosticCategory for EvalDiagnosticCategory { @@ -28,9 +27,8 @@ impl DiagnosticCategory for EvalDiagnosticCategory { fn subcategory(&self) -> Option<&dyn DiagnosticCategory> { match self { - #[cfg(feature = "graph")] - Self::Graph(graph) => Some(graph), Self::Postgres(postgres) => Some(postgres), + Self::Orchestrator(orchestrator) => Some(orchestrator), } } } diff --git a/libs/@local/hashql/eval/src/intern.rs b/libs/@local/hashql/eval/src/intern.rs new file mode 100644 index 00000000000..dee91f141df --- /dev/null +++ b/libs/@local/hashql/eval/src/intern.rs @@ -0,0 +1,28 @@ +use hashql_core::{intern::InternSet, symbol::Symbol}; + +/// Interner for the evaluation stage. +/// +/// Must be created from the MIR interner via [`From`] to preserve +/// [`Interned`](hashql_core::intern::Interned) pointer identity across +/// the MIR/eval boundary. +#[derive(Debug)] +pub struct Interner<'heap> { + pub symbols: InternSet<'heap, [Symbol<'heap>]>, +} + +#[cfg(test)] +impl<'heap> Interner<'heap> { + pub(crate) fn testing(heap: &'heap hashql_core::heap::Heap) -> Self { + Self { + symbols: InternSet::new(heap), + } + } +} + +impl<'heap> From> for Interner<'heap> { + fn from(interner: hashql_mir::intern::Interner<'heap>) -> Self { + Self { + symbols: interner.symbols, + } + } +} diff --git a/libs/@local/hashql/eval/src/lib.rs b/libs/@local/hashql/eval/src/lib.rs index 16a54ff5638..fbadfc201b7 100644 --- a/libs/@local/hashql/eval/src/lib.rs +++ b/libs/@local/hashql/eval/src/lib.rs @@ -27,6 +27,7 @@ pub mod context; pub mod error; #[cfg(feature = "graph")] pub mod graph; +pub mod intern; pub mod orchestrator; pub mod postgres; diff --git a/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs b/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs index 0c8fd4dcfd0..85883ba9a9b 100644 --- a/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs +++ b/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs @@ -42,7 +42,7 @@ mod tests; /// [`Unknown`]: hashql_core::type::kind::TypeKind::Unknown pub struct Decoder<'env, 'heap, A> { env: &'env Environment<'heap>, - interner: &'env hashql_mir::intern::Interner<'heap>, + interner: &'env crate::intern::Interner<'heap>, alloc: A, } @@ -50,7 +50,7 @@ pub struct Decoder<'env, 'heap, A> { impl<'env, 'heap, A: Allocator> Decoder<'env, 'heap, A> { pub const fn new( env: &'env Environment<'heap>, - interner: &'env hashql_mir::intern::Interner<'heap>, + interner: &'env crate::intern::Interner<'heap>, alloc: A, ) -> Self { Self { diff --git a/libs/@local/hashql/eval/src/orchestrator/codec/decode/tests.rs b/libs/@local/hashql/eval/src/orchestrator/codec/decode/tests.rs index d632d0eec36..64ca25aa1f4 100644 --- a/libs/@local/hashql/eval/src/orchestrator/codec/decode/tests.rs +++ b/libs/@local/hashql/eval/src/orchestrator/codec/decode/tests.rs @@ -6,12 +6,10 @@ use hashql_core::{ symbol::sym, r#type::{TypeId, builder::TypeBuilder, environment::Environment}, }; -use hashql_mir::{ - intern::Interner, - interpret::value::{self, Value}, -}; +use hashql_mir::interpret::value::{self, Value}; use super::{DecodeError, Decoder, JsonValueRef}; +use crate::intern::Interner; fn str_value(content: &str) -> Value<'_, Global> { Value::String(value::Str::from(Rc::::from(content))) @@ -28,7 +26,7 @@ fn decoder<'env, 'heap>( fn primitive_string() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -42,7 +40,7 @@ fn primitive_string() { fn primitive_integer() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -57,7 +55,7 @@ fn primitive_integer() { fn primitive_number() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -72,7 +70,7 @@ fn primitive_number() { fn primitive_boolean_true() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -89,7 +87,7 @@ fn primitive_boolean_true() { fn primitive_boolean_false() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -106,7 +104,7 @@ fn primitive_boolean_false() { fn primitive_null() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -120,7 +118,7 @@ fn primitive_null() { fn primitive_type_mismatch() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -132,7 +130,7 @@ fn primitive_type_mismatch() { fn struct_matching_fields() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -157,7 +155,7 @@ fn struct_matching_fields() { fn struct_missing_field() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -174,7 +172,7 @@ fn struct_missing_field() { fn struct_extra_field() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -192,7 +190,7 @@ fn struct_extra_field() { fn tuple_correct_length() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -221,7 +219,7 @@ fn tuple_correct_length() { fn tuple_length_mismatch() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -236,7 +234,7 @@ fn tuple_length_mismatch() { fn union_first_variant_matches() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -253,7 +251,7 @@ fn union_first_variant_matches() { fn union_second_variant_matches() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -269,7 +267,7 @@ fn union_second_variant_matches() { fn union_no_variant_matches() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -283,7 +281,7 @@ fn union_no_variant_matches() { fn opaque_wraps_inner() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -303,7 +301,7 @@ fn opaque_wraps_inner() { fn list_intrinsic() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -329,7 +327,7 @@ fn list_intrinsic() { fn dict_intrinsic() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -360,7 +358,7 @@ fn dict_intrinsic() { fn intersection_type_error() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -374,7 +372,7 @@ fn intersection_type_error() { fn closure_type_error() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -388,7 +386,7 @@ fn closure_type_error() { fn never_type_error() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -400,7 +398,7 @@ fn never_type_error() { fn unknown_type_integer_fallback() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -415,7 +413,7 @@ fn unknown_type_integer_fallback() { fn unknown_type_float_fallback() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -430,7 +428,7 @@ fn unknown_type_float_fallback() { fn unknown_type_array_becomes_list() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -450,7 +448,7 @@ fn unknown_type_array_becomes_list() { fn unknown_type_non_url_object_becomes_dict() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); @@ -469,7 +467,7 @@ fn unknown_type_non_url_object_becomes_dict() { fn unknown_type_url_object_becomes_struct() { let heap = Heap::new(); let env = Environment::new(&heap); - let interner = Interner::new(&heap); + let interner = Interner::testing(&heap); let types = TypeBuilder::synthetic(&env); let decoder = decoder(&env, &interner); diff --git a/libs/@local/hashql/eval/src/orchestrator/events.rs b/libs/@local/hashql/eval/src/orchestrator/events.rs index e2c6b4cd3b6..dd9e3ce337c 100644 --- a/libs/@local/hashql/eval/src/orchestrator/events.rs +++ b/libs/@local/hashql/eval/src/orchestrator/events.rs @@ -146,6 +146,7 @@ impl AppendEventLog { } impl Default for AppendEventLog { + #[inline] fn default() -> Self { Self::new() } diff --git a/libs/@local/hashql/eval/src/orchestrator/mod.rs b/libs/@local/hashql/eval/src/orchestrator/mod.rs index 56e7acb85e2..7cb4198ecaa 100644 --- a/libs/@local/hashql/eval/src/orchestrator/mod.rs +++ b/libs/@local/hashql/eval/src/orchestrator/mod.rs @@ -59,7 +59,7 @@ pub use self::{ error::{OrchestratorDiagnostic, OrchestratorDiagnosticCategory}, events::{AppendEventLog, Event, EventLog}, }; -use crate::{context::EvalContext, postgres::PreparedQueries}; +use crate::{context::CodeExecutionContext, postgres::PreparedQueries}; pub mod codec; pub mod error; @@ -119,7 +119,7 @@ impl Deref for Indexed { pub struct Orchestrator<'env, 'ctx, 'heap, C, E, A: Allocator> { client: C, queries: &'env PreparedQueries<'heap, A>, - context: &'env EvalContext<'ctx, 'heap, A>, + context: &'env CodeExecutionContext<'ctx, 'heap, A>, /// Event sink for execution tracing. See [`EventLog`]. pub event_log: E, } @@ -128,7 +128,7 @@ impl<'env, 'ctx, 'heap, C, A: Allocator> Orchestrator<'env, 'ctx, 'heap, C, (), pub const fn new( client: C, queries: &'env PreparedQueries<'heap, A>, - context: &'env EvalContext<'ctx, 'heap, A>, + context: &'env CodeExecutionContext<'ctx, 'heap, A>, ) -> Self { Self { client, diff --git a/libs/@local/hashql/eval/src/orchestrator/partial.rs b/libs/@local/hashql/eval/src/orchestrator/partial.rs index 385c86aa51b..0231885b92c 100644 --- a/libs/@local/hashql/eval/src/orchestrator/partial.rs +++ b/libs/@local/hashql/eval/src/orchestrator/partial.rs @@ -31,7 +31,6 @@ use hashql_core::{ r#type::{TypeId, environment::Environment}, }; use hashql_mir::{ - intern::Interner, interpret::value::{Int, Num, Opaque, StructBuilder, Value}, pass::execution::{ VertexType, @@ -46,7 +45,7 @@ use super::{ codec::{JsonValueRef, decode::Decoder}, error::BridgeError, }; -use crate::postgres::ColumnDescriptor; +use crate::{intern::Interner, postgres::ColumnDescriptor}; macro_rules! hydrate { ($this:ident -> $entry:ident $(-> $field:ident)+ = $value:expr) => { diff --git a/libs/@local/hashql/eval/src/postgres/filter/mod.rs b/libs/@local/hashql/eval/src/postgres/filter/mod.rs index 588d832dd71..b020e109294 100644 --- a/libs/@local/hashql/eval/src/postgres/filter/mod.rs +++ b/libs/@local/hashql/eval/src/postgres/filter/mod.rs @@ -62,7 +62,7 @@ use super::{ traverse::eval_entity_path, types::{IntegerType, integer_type}, }; -use crate::{context::EvalContext, error::EvalDiagnosticIssues}; +use crate::{context::CodeGenerationContext, error::EvalDiagnosticIssues}; /// Internal representation of a continuation result before casting to the SQL composite type. /// @@ -233,7 +233,7 @@ fn finish_switch_int( /// internal buffer retrievable via [`Self::into_diagnostics`]. pub(crate) struct GraphReadFilterCompiler<'ctx, 'heap, A: Allocator = Global, S: Allocator = Global> { - context: &'ctx EvalContext<'ctx, 'heap, A>, + context: &'ctx CodeGenerationContext<'ctx, 'heap, A>, body: &'ctx Body<'heap>, env: Local, @@ -247,7 +247,7 @@ pub(crate) struct GraphReadFilterCompiler<'ctx, 'heap, A: Allocator = Global, S: impl<'ctx, 'heap, A: Allocator, S: Allocator> GraphReadFilterCompiler<'ctx, 'heap, A, S> { pub(crate) fn new( - context: &'ctx EvalContext<'ctx, 'heap, A>, + context: &'ctx CodeGenerationContext<'ctx, 'heap, A>, body: &'ctx Body<'heap>, env: Local, scratch: S, diff --git a/libs/@local/hashql/eval/src/postgres/filter/tests.rs b/libs/@local/hashql/eval/src/postgres/filter/tests.rs index cc7ef967ac3..486ec2b8707 100644 --- a/libs/@local/hashql/eval/src/postgres/filter/tests.rs +++ b/libs/@local/hashql/eval/src/postgres/filter/tests.rs @@ -40,7 +40,7 @@ use sqruff_lib::core::{config::FluffConfig, linter::core::Linter}; use sqruff_lib_core::dialects::init::DialectKind; use crate::{ - context::EvalContext, + context::CodeGenerationContext, postgres::{DatabaseContext, PostgresCompiler, filter::GraphReadFilterCompiler}, }; @@ -48,6 +48,7 @@ use crate::{ /// and returns everything needed for compilation. struct Fixture<'heap> { env: Environment<'heap>, + interner: crate::intern::Interner<'heap>, bodies: DefIdVec, &'heap Heap>, execution: DefIdVec>, &'heap Heap>, } @@ -89,6 +90,7 @@ impl<'heap> Fixture<'heap> { Self { env, + interner: interner.into(), bodies, execution, } @@ -151,11 +153,10 @@ fn format_body<'heap>(fixture: &Fixture<'heap>, heap: &'heap Heap) -> String { fn compile_filter_islands<'heap>(fixture: &Fixture<'heap>, heap: &'heap Heap) -> FilterReport { let mut scratch = Scratch::new(); let def = fixture.def(); - let interner = Interner::new(heap); - let context = EvalContext::new_in( + let context = CodeGenerationContext::new_in( &fixture.env, - &interner, + &fixture.interner, &fixture.bodies, &fixture.execution, heap, @@ -278,11 +279,10 @@ fn compile_full_query_with_mask<'heap>( ) -> QueryReport { let mut scratch = Scratch::new(); let def = fixture.def(); - let interner = Interner::new(heap); - let mut context = EvalContext::new_in( + let mut context = CodeGenerationContext::new_in( &fixture.env, - &interner, + &fixture.interner, &fixture.bodies, &fixture.execution, heap, diff --git a/libs/@local/hashql/eval/src/postgres/mod.rs b/libs/@local/hashql/eval/src/postgres/mod.rs index 870923a1b30..0f18a5f1f00 100644 --- a/libs/@local/hashql/eval/src/postgres/mod.rs +++ b/libs/@local/hashql/eval/src/postgres/mod.rs @@ -64,7 +64,7 @@ pub use self::{ continuation::ContinuationField, parameters::{Parameter, ParameterIndex, ParameterValue, Parameters, TemporalAxis}, }; -use crate::context::EvalContext; +use crate::context::CodeGenerationContext; mod continuation; pub(crate) mod error; @@ -235,12 +235,12 @@ impl<'heap, A: Allocator> PreparedQueries<'heap, A> { /// Compiles Postgres-targeted MIR islands into a single PostgreSQL `SELECT`. /// /// Created per evaluation and used to compile [`GraphRead`] terminators. Compilation emits -/// diagnostics into the shared [`EvalContext`] rather than returning `Result`, so multiple -/// errors can be reported from a single compilation pass. +/// diagnostics into the shared [`CodeGenerationContext`] rather than returning `Result`, so +/// multiple errors can be reported from a single compilation pass. /// /// [`GraphRead`]: hashql_mir::body::terminator::GraphRead pub struct PostgresCompiler<'eval, 'ctx, 'heap, A: Allocator, S: Allocator> { - context: &'eval mut EvalContext<'ctx, 'heap, A>, + context: &'eval mut CodeGenerationContext<'ctx, 'heap, A>, alloc: A, scratch: S, @@ -257,7 +257,7 @@ pub struct PostgresCompiler<'eval, 'ctx, 'heap, A: Allocator, S: Allocator> { impl<'eval, 'ctx, 'heap, A: Allocator, S: BumpAllocator> PostgresCompiler<'eval, 'ctx, 'heap, A, S> { - pub fn new_in(context: &'eval mut EvalContext<'ctx, 'heap, A>, scratch: S) -> Self + pub fn new_in(context: &'eval mut CodeGenerationContext<'ctx, 'heap, A>, scratch: S) -> Self where A: Clone, { diff --git a/libs/@local/hashql/eval/src/postgres/parameters.rs b/libs/@local/hashql/eval/src/postgres/parameters.rs index 411a93544d9..e929f18fb9e 100644 --- a/libs/@local/hashql/eval/src/postgres/parameters.rs +++ b/libs/@local/hashql/eval/src/postgres/parameters.rs @@ -36,6 +36,7 @@ impl Display for ParameterIndex { } impl From for Expression { + #[inline] fn from(value: ParameterIndex) -> Self { Self::Parameter(value.as_usize() + 1) } diff --git a/libs/@local/hashql/eval/src/postgres/projections.rs b/libs/@local/hashql/eval/src/postgres/projections.rs index c422060d19b..66a55aa58f6 100644 --- a/libs/@local/hashql/eval/src/postgres/projections.rs +++ b/libs/@local/hashql/eval/src/postgres/projections.rs @@ -20,6 +20,7 @@ enum ComputedColumn { } impl From for ColumnName<'_> { + #[inline] fn from(value: ComputedColumn) -> Self { match value { ComputedColumn::EntityTypeIds => ColumnName::from(Identifier::from("entity_type_ids")), diff --git a/libs/@local/hashql/eval/tests/orchestrator/execution.rs b/libs/@local/hashql/eval/tests/orchestrator/execution.rs index 09f3dcd74cf..5bad3bba6ab 100644 --- a/libs/@local/hashql/eval/tests/orchestrator/execution.rs +++ b/libs/@local/hashql/eval/tests/orchestrator/execution.rs @@ -1,18 +1,19 @@ -use alloc::alloc::Global; use core::mem; use hashql_compiletest::pipeline::Pipeline; -use hashql_core::{heap::ResetAllocator as _, span::SpanId}; +use hashql_core::{ + heap::{Heap, ResetAllocator as _}, + span::SpanId, +}; use hashql_diagnostics::{Diagnostic, diagnostic::BoxedDiagnostic}; use hashql_eval::{ - context::EvalContext, + context::{CodeExecutionContext, CodeGenerationContext}, orchestrator::{AppendEventLog, Event, Orchestrator}, postgres::PostgresCompiler, }; use hashql_mir::{ body::Body, def::{DefId, DefIdSlice, DefIdVec}, - intern::Interner, interpret::{Inputs, value::Value}, }; use tokio::runtime; @@ -23,7 +24,7 @@ use tokio_postgres::Client; /// Holds the MIR artifacts needed to build typed inputs (via the decoder /// and the environment) before proceeding to execution. pub(crate) struct Lowered<'heap> { - pub interner: Interner<'heap>, + pub interner: hashql_mir::intern::Interner<'heap>, pub entry: DefId, pub bodies: DefIdVec>, } @@ -65,16 +66,16 @@ pub(crate) fn run<'heap>( runtime: &runtime::Runtime, client: &Client, - inputs: &Inputs<'heap, Global>, + inputs: &Inputs<'heap, &'heap Heap>, - lowered: &mut Lowered<'heap>, -) -> Result<(Value<'heap, Global>, Vec), BoxedDiagnostic<'static, SpanId>> { + mut lowered: Lowered<'heap>, +) -> Result<(Value<'heap, &'heap Heap>, Vec), BoxedDiagnostic<'static, SpanId>> { run_impl( pipeline, runtime, client, inputs, - &lowered.interner, + lowered.interner, lowered.entry, &mut lowered.bodies, ) @@ -94,12 +95,12 @@ pub(crate) fn execute<'heap>( runtime: &runtime::Runtime, client: &Client, - inputs: &Inputs<'heap, Global>, + inputs: &Inputs<'heap, &'heap Heap>, - interner: &Interner<'heap>, + interner: hashql_mir::intern::Interner<'heap>, entry: DefId, bodies: &mut DefIdSlice>, -) -> Result<(Value<'heap, Global>, Vec), BoxedDiagnostic<'static, SpanId>> { +) -> Result<(Value<'heap, &'heap Heap>, Vec), BoxedDiagnostic<'static, SpanId>> { run_impl(pipeline, runtime, client, inputs, interner, entry, bodies) } @@ -116,18 +117,19 @@ fn run_impl<'heap>( runtime: &runtime::Runtime, client: &Client, - inputs: &Inputs<'heap, Global>, + inputs: &Inputs<'heap, &'heap Heap>, - interner: &Interner<'heap>, + interner: hashql_mir::intern::Interner<'heap>, entry: DefId, bodies: &mut DefIdSlice>, -) -> Result<(Value<'heap, Global>, Vec), BoxedDiagnostic<'static, SpanId>> { - pipeline.transform(interner, bodies)?; - let analysis = pipeline.prepare(interner, bodies)?; +) -> Result<(Value<'heap, &'heap Heap>, Vec), BoxedDiagnostic<'static, SpanId>> { + pipeline.transform(&interner, bodies)?; + let analysis = pipeline.prepare(&interner, bodies)?; - let mut context = EvalContext::new_in( + let interner = interner.into(); + let mut context = CodeGenerationContext::new_in( &pipeline.env, - interner, + &interner, bodies, &analysis, pipeline.heap, @@ -141,11 +143,12 @@ fn run_impl<'heap>( pipeline.diagnostics.append(&mut diagnostics.boxed()); let event_log = AppendEventLog::new(); + let context = CodeExecutionContext::from(context); let orchestrator = Orchestrator::new(PostgresClient(client), &queries, &context).with_event_log(&event_log); let value = runtime - .block_on(orchestrator.run(inputs, entry, [])) + .block_on(orchestrator.run_in(inputs, entry, [], pipeline.heap)) .map_err(Diagnostic::generalize) .map_err(Diagnostic::boxed)?; diff --git a/libs/@local/hashql/eval/tests/orchestrator/inputs.rs b/libs/@local/hashql/eval/tests/orchestrator/inputs.rs index d2276c909c8..7fe8d0d7160 100644 --- a/libs/@local/hashql/eval/tests/orchestrator/inputs.rs +++ b/libs/@local/hashql/eval/tests/orchestrator/inputs.rs @@ -1,23 +1,19 @@ -use alloc::alloc::Global; +use alloc::rc::Rc; -use hashql_compiletest::pipeline::Pipeline; use hashql_core::{ - heap::Heap, - module::std_lib::graph::types::{ - knowledge::entity, principal::actor_group::web::types as web_types, - }, - symbol::sym, - r#type::TypeBuilder, + heap::{FromIn as _, Heap}, + intern::InternSet, + symbol::{Symbol, sym}, }; -use hashql_eval::orchestrator::codec::{Decoder, JsonValueRef}; -use hashql_mir::{ - intern::Interner, - interpret::{ - Inputs, - value::{self, Value}, - }, +use hashql_mir::interpret::{ + Inputs, + value::{self, StructBuilder, Value}, }; -use type_system::knowledge::entity::id::EntityUuid; +use type_system::{ + knowledge::entity::id::EntityUuid, + principal::actor_group::{ActorGroupEntityUuid, WebId}, +}; +use uuid::Uuid; use crate::{ directives::{AxisBound, AxisDirectives, AxisInterval}, @@ -25,70 +21,75 @@ use crate::{ }; /// Constructs `Opaque(Timestamp, Integer(ms))`. -fn timestamp_value(ms: i128) -> Value<'static, Global> { +fn timestamp_value(heap: &Heap, ms: i128) -> Value<'_, &Heap> { Value::Opaque(value::Opaque::new( sym::path::Timestamp, - Value::Integer(value::Int::from(ms)), + Rc::new_in(Value::Integer(value::Int::from(ms)), heap), )) } /// Constructs `Opaque(UnboundedTemporalBound, Unit)`. -fn unbounded_bound() -> Value<'static, Global> { +fn unbounded_bound(heap: &Heap) -> Value<'_, &Heap> { Value::Opaque(value::Opaque::new( sym::path::UnboundedTemporalBound, - Value::Unit, + Rc::new_in(Value::Unit, heap), )) } /// Constructs `Opaque(ExclusiveTemporalBound, Timestamp(ms))`. -fn exclusive_bound(ms: i128) -> Value<'static, Global> { +fn exclusive_bound(heap: &Heap, ms: i128) -> Value<'_, &Heap> { Value::Opaque(value::Opaque::new( sym::path::ExclusiveTemporalBound, - timestamp_value(ms), + Rc::new_in(timestamp_value(heap, ms), heap), )) } /// Constructs `Opaque(Interval, {end: .., start: ..})`. /// -/// Fields are sorted lexicographically (`end` before `start`). +/// Field order in `push` calls does not matter; [`StructBuilder::finish`] +/// sorts fields lexicographically. fn interval_value<'heap>( - interner: &Interner<'heap>, - start: Value<'heap, Global>, - end: Value<'heap, Global>, -) -> Value<'heap, Global> { - // Fields sorted: "end" < "start" - let fields = interner.symbols.intern_slice(&[sym::end, sym::start]); - let values = vec![end, start]; + heap: &'heap Heap, + symbols: &InternSet<'heap, [Symbol<'heap>]>, + start: Value<'heap, &'heap Heap>, + end: Value<'heap, &'heap Heap>, +) -> Value<'heap, &'heap Heap> { + let mut builder = StructBuilder::<_, 2>::new(); + builder.push(sym::end, end); + builder.push(sym::start, start); + + let inner = builder.finish(symbols, heap); Value::Opaque(value::Opaque::new( sym::path::Interval, - Value::Struct(value::Struct::new(fields, values).expect("interval struct is valid")), + Rc::new_in(Value::Struct(inner), heap), )) } /// Converts an [`AxisInterval`] to a `Value` representing a temporal /// interval: `Opaque(Interval, {start: , end: })`. fn axis_interval_to_value<'heap>( - interner: &Interner<'heap>, + heap: &'heap Heap, + symbols: &InternSet<'heap, [Symbol<'heap>]>, interval: &AxisInterval, -) -> Value<'heap, Global> { +) -> Value<'heap, &'heap Heap> { let start = match interval.start { - AxisBound::Unbounded => unbounded_bound(), + AxisBound::Unbounded => unbounded_bound(heap), AxisBound::Included(ms) => Value::Opaque(value::Opaque::new( sym::path::InclusiveTemporalBound, - timestamp_value(ms), + Rc::new_in(timestamp_value(heap, ms), heap), )), - AxisBound::Excluded(ms) => exclusive_bound(ms), + AxisBound::Excluded(ms) => exclusive_bound(heap, ms), }; let end = match interval.end { - AxisBound::Unbounded => unbounded_bound(), + AxisBound::Unbounded => unbounded_bound(heap), AxisBound::Included(ms) => Value::Opaque(value::Opaque::new( sym::path::InclusiveTemporalBound, - timestamp_value(ms), + Rc::new_in(timestamp_value(heap, ms), heap), )), - AxisBound::Excluded(ms) => exclusive_bound(ms), + AxisBound::Excluded(ms) => exclusive_bound(heap, ms), }; - interval_value(interner, start, end) + interval_value(heap, symbols, start, end) } /// Returns `true` if the interval is a point (both bounds are Included with @@ -108,9 +109,10 @@ fn is_point(interval: &AxisInterval) -> Option { /// determines which axis is pinned (a point `(T)`) and which is variable /// (a range `[a, b)` or defaulting to unbounded). fn temporal_axes_from_directives<'heap>( - interner: &Interner<'heap>, + heap: &'heap Heap, + symbols: &InternSet<'heap, [Symbol<'heap>]>, directives: &AxisDirectives, -) -> Value<'heap, Global> { +) -> Value<'heap, &'heap Heap> { let far_future_ms: i128 = 4_102_444_800_000; // 2100-01-01T00:00:00Z let default_variable = || AxisInterval { start: AxisBound::Unbounded, @@ -167,15 +169,23 @@ fn temporal_axes_from_directives<'heap>( } }; - let pinned = Value::Opaque(value::Opaque::new(pinned_axis, timestamp_value(pinned_ms))); + let pinned = Value::Opaque(value::Opaque::new( + pinned_axis, + Rc::new_in(timestamp_value(heap, pinned_ms), heap), + )); let variable = Value::Opaque(value::Opaque::new( variable_axis_name, - axis_interval_to_value(interner, &variable_interval), + Rc::new_in( + axis_interval_to_value(heap, symbols, &variable_interval), + heap, + ), )); - // "pinned" < "variable" lexicographically. - let fields = interner.symbols.intern_slice(&[sym::pinned, sym::variable]); - let values = vec![pinned, variable]; + let mut builder = value::StructBuilder::<_, 2>::new(); + builder.push(sym::pinned, pinned); + builder.push(sym::variable, variable); + + let inner = builder.finish(symbols, heap); let wrapper_name = if pinned_axis == sym::path::TransactionTime { sym::path::PinnedTransactionTimeTemporalAxes @@ -185,97 +195,145 @@ fn temporal_axes_from_directives<'heap>( Value::Opaque(value::Opaque::new( wrapper_name, - Value::Struct(value::Struct::new(fields, values).expect("axes struct is valid")), + Rc::new_in(Value::Struct(inner), heap), )) } +fn option<'heap, T>( + heap: &'heap Heap, + value: Option, + on_value: impl FnOnce(&'heap Heap, T) -> value::Value<'heap, &'heap Heap>, +) -> value::Value<'heap, &'heap Heap> { + value.map_or_else( + || { + value::Value::Opaque(value::Opaque::new( + sym::path::None, + Rc::new_in(value::Value::Unit, heap), + )) + }, + |value| { + value::Value::Opaque(value::Opaque::new( + sym::path::Some, + Rc::new_in(on_value(heap, value), heap), + )) + }, + ) +} + /// Builds the shared input set from seeded entity data and axis directives. /// -/// Uses the decoder and the post-lowering type environment to construct -/// properly typed `Value`s for entity UUIDs and entity IDs. The input names -/// match what J-Expr test files reference via `["input", "", ""]`. +/// Constructs interpreter [`Value`]s directly from the Rust-typed seed data, +/// mirroring the opaque wrapping structure of the HashQL type system +/// (e.g. `EntityId(Struct { web_id: WebId(ActorGroupEntityUuid(Uuid(String))), ... })`). +/// +/// The input names match what J-Expr test files reference via +/// `["input", "", ""]`. pub(crate) fn build_inputs<'heap>( heap: &'heap Heap, - pipeline: &Pipeline<'heap>, - interner: &Interner<'heap>, + symbols: &InternSet<'heap, [Symbol<'heap>]>, entities: &SeededEntities, directives: &AxisDirectives, -) -> Inputs<'heap, Global> { - let mut inputs = Inputs::new(); - let decoder = Decoder::new(&pipeline.env, interner, Global); - let ty = TypeBuilder::synthetic(&pipeline.env); - let entity_uuid_type = entity::types::entity_uuid(&ty, None); - let entity_id_type = entity::types::entity_id(&ty, None); +) -> Inputs<'heap, &'heap Heap> { + let mut inputs = Inputs::new_in(heap); - // Insert an EntityUuid-typed input. - let insert_uuid = |inputs: &mut Inputs<'heap, Global>, name: &str, uuid: &EntityUuid| { - let uuid_str = uuid.to_string(); - let value = decoder - .decode(entity_uuid_type, JsonValueRef::String(&uuid_str)) - .expect("could not decode EntityUuid input"); + let string = |value: &str| value::Value::String(value::Str::from(Rc::from_in(value, heap))); + + let uuid = |value: Uuid| { + value::Value::Opaque(value::Opaque::new( + sym::path::Uuid, + Rc::new_in(string(value.to_string().as_str()), heap), + )) + }; + + let entity_uuid = |value: EntityUuid| { + value::Value::Opaque(value::Opaque::new( + sym::path::EntityUuid, + Rc::new_in(uuid(value.into()), heap), + )) + }; + + let actor_group_entity_uuid = |value: ActorGroupEntityUuid| { + value::Value::Opaque(value::Opaque::new( + sym::path::ActorGroupEntityUuid, + Rc::new_in(uuid(value.into()), heap), + )) + }; + + let web_id = |value: WebId| { + value::Value::Opaque(value::Opaque::new( + sym::path::WebId, + Rc::new_in(actor_group_entity_uuid(value.into()), heap), + )) + }; - inputs.insert(heap.intern_symbol(name), value); + let draft_id = |value: Option| { + option(heap, value, |heap, value| { + value::Value::Opaque(value::Opaque::new( + sym::path::DraftId, + Rc::new_in(uuid(value.into()), heap), + )) + }) }; + let entity_id = |value: type_system::knowledge::entity::id::EntityId| { + let mut builder = StructBuilder::<_, 3>::new(); + builder.push(sym::web_id, web_id(value.web_id)); + builder.push(sym::entity_uuid, entity_uuid(value.entity_uuid)); + builder.push(sym::draft_id, draft_id(value.draft_id)); + + let r#struct = builder.finish(symbols, heap); + let inner = value::Value::Struct(r#struct); + value::Value::Opaque(value::Opaque::new( + sym::path::EntityId, + Rc::new_in(inner, heap), + )) + }; + + // Insert an EntityUuid-typed input. + let insert_entity_uuid = + |inputs: &mut Inputs<'heap, &'heap Heap>, name: &str, uuid: EntityUuid| { + inputs.insert(heap.intern_symbol(name), entity_uuid(uuid)); + }; + // Insert a full EntityId-typed input. let insert_entity_id = - |inputs: &mut Inputs<'heap, Global>, + |inputs: &mut Inputs<'heap, &'heap Heap>, name: &str, - id: &type_system::knowledge::entity::EntityId| { - let json = serde_json::json!({ - "web_id": id.web_id.to_string(), - "entity_uuid": id.entity_uuid.to_string(), - "draft_id": id.draft_id.map(|draft| draft.to_string()), - }); - let value = decoder - .decode(entity_id_type, JsonValueRef::from(&json)) - .expect("could not decode EntityId input"); - - inputs.insert(heap.intern_symbol(name), value); + id: type_system::knowledge::entity::EntityId| { + inputs.insert(heap.intern_symbol(name), entity_id(id)); }; - insert_uuid(&mut inputs, "alice_uuid", &entities.alice.entity_uuid); - insert_uuid(&mut inputs, "bob_uuid", &entities.bob.entity_uuid); - insert_uuid(&mut inputs, "org_uuid", &entities.organization.entity_uuid); - insert_uuid( + insert_entity_uuid(&mut inputs, "alice_uuid", entities.alice.entity_uuid); + insert_entity_uuid(&mut inputs, "bob_uuid", entities.bob.entity_uuid); + insert_entity_uuid(&mut inputs, "org_uuid", entities.organization.entity_uuid); + insert_entity_uuid( &mut inputs, "friend_link_uuid", - &entities.friend_link.entity_uuid, + entities.friend_link.entity_uuid, ); - insert_uuid( + insert_entity_uuid( &mut inputs, "draft_alice_uuid", - &entities.draft_alice.entity_uuid, + entities.draft_alice.entity_uuid, ); - insert_entity_id(&mut inputs, "alice_id", &entities.alice); - insert_entity_id(&mut inputs, "bob_id", &entities.bob); - insert_entity_id(&mut inputs, "org_id", &entities.organization); - insert_entity_id(&mut inputs, "friend_link_id", &entities.friend_link); - insert_entity_id(&mut inputs, "draft_alice_id", &entities.draft_alice); + insert_entity_id(&mut inputs, "alice_id", entities.alice); + insert_entity_id(&mut inputs, "bob_id", entities.bob); + insert_entity_id(&mut inputs, "org_id", entities.organization); + insert_entity_id(&mut inputs, "friend_link_id", entities.friend_link); + insert_entity_id(&mut inputs, "draft_alice_id", entities.draft_alice); // WebId input (all seeded entities share the same web). - let web_id_type = web_types::web_id(&ty, None); - let web_id_value = decoder - .decode( - web_id_type, - JsonValueRef::String(&entities.alice.web_id.to_string()), - ) - .expect("could not decode WebId input"); - inputs.insert(heap.intern_symbol("web_id"), web_id_value); + inputs.insert(heap.intern_symbol("web_id"), web_id(entities.alice.web_id)); // String inputs for property-based filtering. - let string_type = ty.string(); - let alice_name = decoder - .decode(string_type, JsonValueRef::String("Alice")) - .expect("could not decode string input"); - inputs.insert(heap.intern_symbol("alice_name"), alice_name); + inputs.insert(heap.intern_symbol("alice_name"), string("Alice")); // Temporal axes from directives (or default: unbounded decision time, // far-future transaction pin). inputs.insert( heap.intern_symbol("temporal_axes"), - temporal_axes_from_directives(interner, directives), + temporal_axes_from_directives(heap, symbols, directives), ); inputs diff --git a/libs/@local/hashql/eval/tests/orchestrator/main.rs b/libs/@local/hashql/eval/tests/orchestrator/main.rs index 53d4b4ea7be..05a0d5efc4c 100644 --- a/libs/@local/hashql/eval/tests/orchestrator/main.rs +++ b/libs/@local/hashql/eval/tests/orchestrator/main.rs @@ -108,8 +108,7 @@ fn run_jexpr_test( let heap = Heap::new(); let mut pipeline = Pipeline::new(&heap); - // Lower first so the type environment is populated, then build inputs. - let mut lowered = match execution::lower(&mut pipeline, &bytes) { + let lowered = match execution::lower(&mut pipeline, &bytes) { Ok(lowered) => lowered, Err(diagnostic) => { let rendered = render_failure(&source, &pipeline, &diagnostic); @@ -119,8 +118,7 @@ fn run_jexpr_test( let inputs = build_inputs( &heap, - &pipeline, - &lowered.interner, + &lowered.interner.symbols, &context.entities, &axis_directives, ); @@ -130,7 +128,7 @@ fn run_jexpr_test( runtime, context.store.as_client(), &inputs, - &mut lowered, + lowered, ) { Ok((value, events)) => { let rendered = render_success(&source, &value, &events, &pipeline)?; @@ -162,8 +160,7 @@ fn run_programmatic_test( let inputs = build_inputs( &heap, - &pipeline, - &interner, + &interner.symbols, &context.entities, &AxisDirectives::default(), ); @@ -177,7 +174,7 @@ fn run_programmatic_test( runtime, context.store.as_client(), &inputs, - &interner, + interner, entry, &mut bodies, ) { diff --git a/libs/@local/hashql/eval/tests/orchestrator/output.rs b/libs/@local/hashql/eval/tests/orchestrator/output.rs index e146bf6e13d..5287ad15a3f 100644 --- a/libs/@local/hashql/eval/tests/orchestrator/output.rs +++ b/libs/@local/hashql/eval/tests/orchestrator/output.rs @@ -1,4 +1,4 @@ -use alloc::alloc::Global; +use core::alloc::Allocator; use std::{collections::HashMap, fs, path::Path, sync::LazyLock}; use error_stack::{Report, ResultExt as _}; @@ -108,9 +108,9 @@ fn normalize(input: &str) -> String { /// /// Returns [`TestError::Serialization`] if the value cannot be serialized to /// JSON. -pub(crate) fn render_success( +pub(crate) fn render_success( source: &str, - value: &Value<'_, Global>, + value: &Value<'_, A>, events: &[Event], pipeline: &Pipeline<'_>, ) -> Result> { diff --git a/tests/graph/benches/manual_queries/entity_queries/mod.rs b/tests/graph/benches/manual_queries/entity_queries/mod.rs index d03a19d57b7..41d4d3c5cb7 100644 --- a/tests/graph/benches/manual_queries/entity_queries/mod.rs +++ b/tests/graph/benches/manual_queries/entity_queries/mod.rs @@ -6,8 +6,8 @@ use criterion_macro::criterion; use either::Either; use error_stack::Report; use hash_graph_api::rest::{ - self, ApiConfig, - entity::{EntityQueryOptions, QueryEntitiesRequest, QueryEntitySubgraphRequest}, + ApiConfig, + entity::query::{QueryEntitiesRequest, QueryEntitySubgraphRequest}, }; use hash_graph_postgres_store::{ Environment, load_env, @@ -139,13 +139,11 @@ impl QueryEntitiesQuery<'_, '_, '_> { let modifies_actor_id = !self.settings.parameters.actor_id.is_empty(); let modifies_limit = !self.settings.parameters.limit.is_empty(); - let (query, options) = self.request.into_parts(); - let actor_id = iter::once(self.actor_id) .chain(mem::take(&mut self.settings.parameters.actor_id)) .sorted_by_key(|actor_id| Uuid::from(*actor_id)) .dedup(); - let limit = iter::once(options.limit) + let limit = iter::once(self.request.limit) .chain( mem::take(&mut self.settings.parameters.limit) .into_iter() @@ -165,13 +163,10 @@ impl QueryEntitiesQuery<'_, '_, '_> { ( Self { actor_id, - request: QueryEntitiesRequest::from_parts( - query.clone(), - EntityQueryOptions { - limit, - ..options.clone() - }, - ), + request: QueryEntitiesRequest { + limit, + ..self.request.clone() + }, settings: self.settings.clone(), }, parameters.join(","), @@ -238,13 +233,13 @@ impl QueryEntitySubgraphQuery<'_, '_, '_> { let modifies_limit = !self.settings.parameters.limit.is_empty(); let modifies_graph_resolve_depths = !self.settings.parameters.traversal_params.is_empty(); - let (query, options, traversal_params) = self.request.clone().into_parts(); + let (request, traversal_params) = self.request.clone().into_parts(); let actor_id = iter::once(self.actor_id) .chain(mem::take(&mut self.settings.parameters.actor_id)) .sorted_by_key(|actor_id| Uuid::from(*actor_id)) .dedup(); - let limit = iter::once(options.limit) + let limit = iter::once(request.limit) .chain( mem::take(&mut self.settings.parameters.limit) .into_iter() @@ -252,6 +247,10 @@ impl QueryEntitySubgraphQuery<'_, '_, '_> { ) .sorted() .dedup(); + let include_count = iter::once(request.include_count) + .chain(mem::take(&mut self.settings.parameters.include_count)) + .sorted() + .dedup(); let traversal_params_iter = iter::once(traversal_params) .chain(mem::take(&mut self.settings.parameters.traversal_params)); @@ -271,10 +270,9 @@ impl QueryEntitySubgraphQuery<'_, '_, '_> { Self { actor_id, request: QueryEntitySubgraphRequest::from_parts( - query.clone(), - EntityQueryOptions { + QueryEntitiesRequest { limit, - ..options.clone() + ..request.clone() }, traversal_params, ), @@ -320,33 +318,19 @@ where match request { GraphQuery::QueryEntities(request) => { - let (query, options) = request.request.into_parts(); - let rest::entity::EntityQuery::Filter { filter } = query else { - panic!("unsupported query type") - }; - let _response = store .query_entities( request.actor_id, - options - .into_params(filter, config) - .expect("limit should not exceed configured maximum"), + request.request.into_params_unchecked(config, None), ) .await .expect("failed to read entities from store"); } GraphQuery::QueryEntitySubgraph(request) => { - let (query, options, traversal) = request.request.into_parts(); - let rest::entity::EntityQuery::Filter { filter } = query else { - panic!("unsupported query type") - }; - let _response = store .query_entity_subgraph( request.actor_id, - options - .into_traversal_params(filter, traversal, config) - .expect("limit should not exceed configured maximum"), + request.request.into_traversal_params_unchecked(config), ) .await .expect("failed to read entity subgraph from store");