diff --git a/Cargo.lock b/Cargo.lock index 5a3d5a1..6d7682c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -44,6 +44,16 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "astral-tokio-tar" version = "0.5.6" @@ -663,11 +673,13 @@ dependencies = [ "bon", "chrono", "mockall", + "reqwest", "serde", "serde_json", "thiserror 2.0.18", "tokio", "uuid", + "wiremock", ] [[package]] @@ -710,6 +722,7 @@ dependencies = [ "anyhow", "async-trait", "axum", + "bon", "dataplane-sdk", "dataplane-sdk-axum", "dataplane-sdk-postgres", @@ -724,6 +737,24 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "deadpool" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" +dependencies = [ + "deadpool-runtime", + "lazy_static", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" + [[package]] name = "der" version = "0.7.10" @@ -1243,6 +1274,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -1923,6 +1960,16 @@ dependencies = [ "libm", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -4510,6 +4557,29 @@ dependencies = [ "memchr", ] +[[package]] +name = "wiremock" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" +dependencies = [ + "assert-json-diff", + "base64 0.22.1", + "deadpool", + "futures", + "http", + "http-body-util", + "hyper", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", + "url", +] + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/crates/sdk-axum/src/api.rs b/crates/sdk-axum/src/api.rs index cd92ff5..1b5617d 100644 --- a/crates/sdk-axum/src/api.rs +++ b/crates/sdk-axum/src/api.rs @@ -13,11 +13,13 @@ use axum::{ Extension, Json, extract::{Path, State}, + http::StatusCode, }; use dataplane_sdk::{ core::{ db::tx::TransactionalContext, model::{ + data_flow::DataFlowState, messages::{ DataFlowPrepareMessage, DataFlowResumeMessage, DataFlowStartMessage, DataFlowStartedNotificationMessage, DataFlowStatusMessage, @@ -48,12 +50,18 @@ pub async fn prepare_flow( State(sdk): State>, Extension(participant): Extension, Json(msg): Json, -) -> SignalingResult> +) -> SignalingResult<(StatusCode, Json)> where C: TransactionalContext, { let response = sdk.prepare(&participant.id, msg).await?; - Ok(Json(response)) + + let status = match response.state { + DataFlowState::Preparing => StatusCode::OK, + _ => StatusCode::OK, + }; + + Ok((status, Json(response))) } #[derive(Deserialize)] diff --git a/crates/sdk-postgres/migrations/20260218164640_create_data_flows.sql b/crates/sdk-postgres/migrations/20260218164640_create_data_flows.sql index 7f14f0d..7de62f6 100644 --- a/crates/sdk-postgres/migrations/20260218164640_create_data_flows.sql +++ b/crates/sdk-postgres/migrations/20260218164640_create_data_flows.sql @@ -1,4 +1,4 @@ -CREATE TYPE data_flow_state AS ENUM ('started','suspended','terminated','completed', 'initiating', 'initiated', 'preparing', 'prepared'); +CREATE TYPE data_flow_state AS ENUM ('starting','started','suspended','terminated','completed', 'initiating', 'initiated', 'preparing', 'prepared'); CREATE TYPE data_flow_type AS ENUM ('consumer','provider'); diff --git a/crates/sdk-tck-tests/Cargo.toml b/crates/sdk-tck-tests/Cargo.toml index 3c93a6b..643ba5d 100644 --- a/crates/sdk-tck-tests/Cargo.toml +++ b/crates/sdk-tck-tests/Cargo.toml @@ -21,6 +21,7 @@ axum.workspace=true tower-http = { version = "0.6.9", features = ["trace"]} futures.workspace=true regex.workspace=true +bon.workspace=true [lints] workspace = true diff --git a/crates/sdk-tck-tests/tests/dps.tck.properties b/crates/sdk-tck-tests/tests/dps.tck.properties index a8283b6..491954e 100644 --- a/crates/sdk-tck-tests/tests/dps.tck.properties +++ b/crates/sdk-tck-tests/tests/dps.tck.properties @@ -11,13 +11,18 @@ DP_P_PULL_01_01_TRANSFERTYPE=http_pull_sync DP_P_PULL_01_02_TRANSFERTYPE=http_pull_sync DP_P_PULL_02_01_TRANSFERTYPE=http_pull_sync DP_P_PULL_02_02_TRANSFERTYPE=http_pull_sync - +DP_P_PULL_04_01_TRANSFERTYPE=http_pull_async +DP_P_PULL_04_01_AGREEMENTID=initiating-started # pull consumer side DP_C_PULL_01_01_TRANSFERTYPE=http_pull_sync DP_C_PULL_01_02_TRANSFERTYPE=http_pull_sync DP_C_PULL_02_01_TRANSFERTYPE=http_pull_sync DP_C_PULL_02_02_TRANSFERTYPE=http_pull_sync +DP_C_PULL_03_01_TRANSFERTYPE=http_pull_sync +DP_C_PULL_03_01_AGREEMENTID=prepared-completed +DP_C_PULL_04_01_TRANSFERTYPE=http_pull_async +DP_C_PULL_04_01_AGREEMENTID=initiating-prepared # push consumer side DP_C_PUSH_01_01_TRANSFERTYPE=http_push_sync @@ -25,10 +30,15 @@ DP_C_PUSH_01_02_TRANSFERTYPE=http_push_sync DP_C_PUSH_02_01_TRANSFERTYPE=http_push_sync DP_C_PUSH_02_02_TRANSFERTYPE=http_push_sync DP_C_PUSH_03_01_TRANSFERTYPE=http_push_sync +DP_C_PUSH_04_01_TRANSFERTYPE=http_push_async +DP_C_PUSH_04_01_AGREEMENTID=initiating-prepared # push provider side DP_P_PUSH_01_01_TRANSFERTYPE=http_push_sync DP_P_PUSH_01_02_TRANSFERTYPE=http_push_sync DP_P_PUSH_02_01_TRANSFERTYPE=http_push_sync DP_P_PUSH_02_02_TRANSFERTYPE=http_push_sync - +DP_P_PUSH_03_01_TRANSFERTYPE=http_push_sync +DP_P_PUSH_03_01_AGREEMENTID=initiating-completed +DP_P_PUSH_04_01_TRANSFERTYPE=http_push_async +DP_P_PUSH_04_01_AGREEMENTID=initiating-started diff --git a/crates/sdk-tck-tests/tests/tck_tests.rs b/crates/sdk-tck-tests/tests/tck_tests.rs index 0ae2027..9b1a780 100644 --- a/crates/sdk-tck-tests/tests/tck_tests.rs +++ b/crates/sdk-tck-tests/tests/tck_tests.rs @@ -10,8 +10,9 @@ // Metaform Systems, Inc. - initial API and implementation // -use std::collections::HashMap; +use std::{collections::HashMap, time::Duration}; +use bon::Builder; use dataplane_sdk::core::{ error::{HandlerError, HandlerResult}, handler::DataFlowHandler, @@ -22,24 +23,62 @@ use dataplane_sdk::core::{ }, }; use dataplane_sdk_postgres::PgTransaction; +use tokio::sync::mpsc::Sender; use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; mod util; #[cfg(test)] mod tck_tests { + use dataplane_sdk::{core::db::tx::TransactionalContext, sdk::DataPlaneSdk}; + use tokio::sync::mpsc::Receiver; + use crate::util::TckTestReporter; use super::*; - static EXPECTED_FAILURES: &[&str] = &[ - "DP_P_PULL:04-01", - "DP_C_PULL:03-01", - "DP_C_PULL:04-01", - "DP_C_PUSH:04-01", - "DP_P_PUSH:03-01", - "DP_P_PUSH:04-01", - ]; + fn handle_notifications( + sdk: DataPlaneSdk, + rx: Receiver, + ) where + ::Transaction: std::marker::Send, + { + tokio::task::spawn(async move { + let mut stream = rx; + while let Some(notification) = stream.recv().await { + tokio::time::sleep(Duration::from_millis(250)).await; + let flow = ¬ification.flow; + match notification.kind { + NotificationKind::Started => { + sdk.notify_started(&flow.participant_context_id, &flow.id, None) + .await + .unwrap(); + } + NotificationKind::Suspended => tracing::info!( + "Flow suspended: ID: {}, state: {:?}", + notification.flow.id, + notification.flow.state + ), + NotificationKind::Completed => { + sdk.notify_completed(&flow.participant_context_id, &flow.id) + .await + .unwrap(); + } + NotificationKind::Errored(err) => tracing::error!( + "Flow errored: ID: {}, state: {:?}, error: {}", + notification.flow.id, + notification.flow.state, + err + ), + NotificationKind::Prepared => { + sdk.notify_prepared(&flow.participant_context_id, &flow.id, None) + .await + .unwrap(); + } + } + } + }); + } #[tokio::test] async fn dataplane_tck_test() { @@ -48,7 +87,12 @@ mod tck_tests { .with(tracing_subscriber::fmt::layer()) .init(); let (ctx, repo, _container) = util::setup_postgres_container().await; - let sdk = util::sdk(ctx, repo, TckTestHandler::new()).await; + + let (tx, rx) = tokio::sync::mpsc::channel(100); + + let sdk = util::sdk(ctx, repo, TckTestHandler::new(tx)).await; + + handle_notifications(sdk.clone(), rx); util::start_signaling(8282, sdk).await; @@ -56,8 +100,7 @@ mod tck_tests { let _tck_container = util::setup_tck_container(reporter.clone()).await; - let mut failures = reporter.failures(); - failures.retain(|f| !EXPECTED_FAILURES.contains(&f.as_str())); + let failures = reporter.failures(); assert!( failures.is_empty(), @@ -72,24 +115,50 @@ fn env_filter() -> EnvFilter { .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()) } -pub type Action = Box HandlerResult + Send + Sync>; +pub type Handler = Box HandlerResult + Send + Sync>; + +#[derive(Builder, Clone)] +pub struct Notification { + flow: DataFlow, + kind: NotificationKind, +} + +#[derive(Clone, Debug)] +pub enum NotificationKind { + Started, + Suspended, + Completed, + Prepared, + Errored(String), +} pub struct TckTestHandler { - actions: HashMap, + handlers: HashMap, + sender: Sender, } impl TckTestHandler { - pub fn new() -> Self { - let mut actions: HashMap = HashMap::new(); + pub fn new(sender: Sender) -> Self { + let mut handlers: HashMap = HashMap::new(); - actions.insert( + handlers.insert( "http_pull_sync".to_string(), Box::new(|flow| http_pull_sync(flow)), ); - actions.insert( + handlers.insert( "http_push_sync".to_string(), Box::new(|flow| http_push_sync(flow)), ); - Self { actions } + + handlers.insert( + "http_pull_async".to_string(), + Box::new(|flow| http_pull_async(flow)), + ); + + handlers.insert( + "http_push_async".to_string(), + Box::new(|flow| http_push_async(flow)), + ); + Self { handlers, sender } } } @@ -106,7 +175,8 @@ impl DataFlowHandler for TckTestHandler { tx: &mut Self::Transaction, flow: &DataFlow, ) -> HandlerResult { - self.actions + self.fire_notification(flow).await; + self.handlers .get(&flow.transfer_type) .map(|action| action(flow)) .ok_or_else(|| { @@ -122,7 +192,8 @@ impl DataFlowHandler for TckTestHandler { tx: &mut Self::Transaction, flow: &DataFlow, ) -> HandlerResult { - self.actions + self.fire_notification(flow).await; + self.handlers .get(&flow.transfer_type) .map(|action| action(flow)) .ok_or_else(|| { @@ -134,18 +205,57 @@ impl DataFlowHandler for TckTestHandler { } async fn on_terminate(&self, tx: &mut Self::Transaction, flow: &DataFlow) -> HandlerResult<()> { + self.fire_notification(flow).await; Ok(()) } async fn on_started(&self, tx: &mut Self::Transaction, flow: &DataFlow) -> HandlerResult<()> { + self.fire_notification(flow).await; Ok(()) } async fn on_suspend(&self, tx: &mut Self::Transaction, flow: &DataFlow) -> HandlerResult<()> { + self.fire_notification(flow).await; Ok(()) } } +fn matches_state(state: &str, expected: DataFlowState) -> bool { + match expected { + DataFlowState::Prepared => state == "prepared", + DataFlowState::Started => state == "started", + DataFlowState::Initiating => state == "initiating", + _ => false, + } +} + +impl TckTestHandler { + async fn fire_notification(&self, flow: &DataFlow) { + let notification = flow.agreement_id.split("-").collect::>(); + let kind = match notification.as_slice() { + [state, "completed"] if matches_state(state, flow.state.clone()) => { + NotificationKind::Completed + } + [state, "prepared"] if matches_state(state, flow.state.clone()) => { + NotificationKind::Prepared + } + [state, "started"] if matches_state(state, flow.state.clone()) => { + NotificationKind::Started + } + [state, "error"] if matches_state(state, flow.state.clone()) => { + NotificationKind::Errored("Simulated error".to_string()) + } + _ => return, // No notification for other agreement IDs + }; + let notification = Notification::builder() + .flow(flow.clone()) + .kind(kind) + .build(); + + self.sender.send(notification).await.unwrap(); + } +} + pub fn http_pull_sync(flow: &DataFlow) -> HandlerResult { let (data_address, state) = match flow.kind { DataFlowType::Consumer => (None, DataFlowState::Prepared), @@ -166,6 +276,30 @@ pub fn http_pull_sync(flow: &DataFlow) -> HandlerResult { .build()) } +pub fn http_pull_async(flow: &DataFlow) -> HandlerResult { + let (data_address, state) = match flow.kind { + DataFlowType::Consumer => (None, DataFlowState::Preparing), + DataFlowType::Provider => (None, DataFlowState::Starting), + }; + Ok(DataFlowStatusMessage::builder() + .data_flow_id(flow.id.clone()) + .maybe_data_address(data_address) + .state(state) + .build()) +} + +pub fn http_push_async(flow: &DataFlow) -> HandlerResult { + let (data_address, state) = match flow.kind { + DataFlowType::Consumer => (None, DataFlowState::Preparing), + DataFlowType::Provider => (None, DataFlowState::Starting), + }; + Ok(DataFlowStatusMessage::builder() + .data_flow_id(flow.id.clone()) + .maybe_data_address(data_address) + .state(state) + .build()) +} + pub fn http_push_sync(flow: &DataFlow) -> HandlerResult { let (data_address, state) = match flow.kind { DataFlowType::Consumer => ( diff --git a/crates/sdk-tck-tests/tests/util.rs b/crates/sdk-tck-tests/tests/util.rs index fe73331..e62a517 100644 --- a/crates/sdk-tck-tests/tests/util.rs +++ b/crates/sdk-tck-tests/tests/util.rs @@ -36,7 +36,7 @@ use futures::future::BoxFuture; use regex::Regex; use testcontainers::{ GenericImage, ImageExt, - core::{Host, Mount, WaitFor, logs::consumer::LogConsumer}, + core::{ContainerPort, Host, IntoContainerPort, Mount, WaitFor, logs::consumer::LogConsumer}, runners::AsyncRunner, }; use testcontainers_modules::postgres::Postgres; @@ -88,7 +88,9 @@ pub async fn setup_tck_container( ) -> testcontainers::ContainerAsync { let path = Path::new("tests/dps.tck.properties"); GenericImage::new("eclipsedataspacetck/dps-tck-runtime", "1.1.2") + .with_exposed_port(8083.tcp()) .with_wait_for(WaitFor::message_on_stdout("Test run complete")) + .with_mapped_port(8083, ContainerPort::Tcp(8083)) .with_mount(Mount::bind_mount( path::absolute(path).unwrap().as_os_str().to_str().unwrap(), "/etc/tck/config.properties", diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index ec5656c..7b9186f 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -13,15 +13,17 @@ bon.workspace=true chrono.workspace=true thiserror.workspace=true async-trait.workspace=true -uuid = { workspace=true, optional = true} +uuid = { workspace=true } +reqwest.workspace=true [features] -test = ["dep:uuid"] +test = [] [dev-dependencies] mockall.workspace=true tokio.workspace=true +wiremock = "0.6" [lints] workspace = true diff --git a/crates/sdk/src/core/model/data_flow.rs b/crates/sdk/src/core/model/data_flow.rs index 81f1885..eed006c 100644 --- a/crates/sdk/src/core/model/data_flow.rs +++ b/crates/sdk/src/core/model/data_flow.rs @@ -108,7 +108,7 @@ impl DataFlow { pub fn transition_to_starting(&mut self) -> Result<(), TransitionError> { match self.state { - DataFlowState::Prepared => { + DataFlowState::Initiating | DataFlowState::Prepared => { self.state = DataFlowState::Starting; self.updated_at = chrono::Utc::now(); Ok(()) diff --git a/crates/sdk/src/core/model/messages.rs b/crates/sdk/src/core/model/messages.rs index 73a86fd..d1390bc 100644 --- a/crates/sdk/src/core/model/messages.rs +++ b/crates/sdk/src/core/model/messages.rs @@ -72,12 +72,21 @@ pub struct DataFlowPrepareMessage { #[serde(rename_all = "camelCase")] #[builder(on(String, into))] pub struct DataFlowStatusMessage { + #[builder(default = new_message_id())] + #[serde(default = "new_message_id")] + pub message_id: String, pub data_flow_id: String, + #[serde(skip_serializing_if = "Option::is_none")] pub data_address: Option, pub state: DataFlowState, + #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } +fn new_message_id() -> String { + uuid::Uuid::new_v4().to_string() +} + #[derive(Debug, Builder, Serialize, Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct DataFlowStartedNotificationMessage { diff --git a/crates/sdk/src/error.rs b/crates/sdk/src/error.rs index a25a494..a8879c6 100644 --- a/crates/sdk/src/error.rs +++ b/crates/sdk/src/error.rs @@ -25,4 +25,8 @@ pub enum SdkError { Repo(#[from] DbError), #[error("Transition error: {0}")] Transition(#[from] TransitionError), + #[error("Notification transport error: {0}")] + Notification(#[from] reqwest::Error), + #[error("Notification rejected with status {status}: {body}")] + NotificationStatus { status: u16, body: String }, } diff --git a/crates/sdk/src/sdk.rs b/crates/sdk/src/sdk.rs index 0680334..416847c 100644 --- a/crates/sdk/src/sdk.rs +++ b/crates/sdk/src/sdk.rs @@ -41,11 +41,13 @@ where ctx: C, repo: Box>, handler: Box>, + client: reqwest::Client, ) -> Self { Self(Arc::new(internal::DataPlaneSdkInternal { ctx, repo, handler, + client, })) } } @@ -68,6 +70,7 @@ where ctx: C, repo: Option>>, handler: Option>>, + client: Option, } impl DataPlaneSdkBuilder @@ -79,6 +82,7 @@ where ctx, repo: None, handler: None, + client: None, } } @@ -98,11 +102,20 @@ where self } + /// Overrides the `reqwest::Client` used to send control plane notifications. + /// Defaults to `reqwest::Client::new()` when not set. + pub fn with_client(mut self, client: reqwest::Client) -> Self { + self.client = Some(client); + self + } + pub fn build(self) -> Result, String> { let repo = self.repo.ok_or("DataFlowRepo is not set")?; let handler = self.handler.ok_or("DataFlowHandler is not set")?; - Ok(DataPlaneSdk::new(self.ctx, repo, handler)) + let client = self.client.unwrap_or_default(); + + Ok(DataPlaneSdk::new(self.ctx, repo, handler, client)) } } diff --git a/crates/sdk/src/sdk/internal.rs b/crates/sdk/src/sdk/internal.rs index 1d0a6d2..051e9a2 100644 --- a/crates/sdk/src/sdk/internal.rs +++ b/crates/sdk/src/sdk/internal.rs @@ -19,7 +19,8 @@ use crate::{ error::{DbError, HandlerError}, handler::DataFlowHandler, model::{ - data_flow::{DataFlow, DataFlowState, DataFlowType}, + data_address::DataAddress, + data_flow::{DataFlow, DataFlowState, DataFlowType, TransitionError}, messages::{ DataFlowPrepareMessage, DataFlowResumeMessage, DataFlowStartMessage, DataFlowStartedNotificationMessage, DataFlowStatusMessage, @@ -37,6 +38,7 @@ where pub(crate) ctx: C, pub(crate) repo: Box>, pub(crate) handler: Box>, + pub(crate) client: reqwest::Client, } impl DataPlaneSdkInternal @@ -279,6 +281,120 @@ where .build()) } + /// Notifies the control plane that the data flow has been prepared, by + /// POSTing to the flow's `callback_address`. See the Data Plane Signalling + /// "Control Plane Endpoint" section. + pub async fn notify_prepared( + &self, + ctx: &str, + flow_id: &str, + data_address: Option, + ) -> SdkResult<()> { + self.send_callback( + ctx, + flow_id, + "prepared", + data_address.clone(), + None, + move |flow| { + flow.data_address = data_address; + flow.transition_to_prepared() + }, + ) + .await + } + + /// Notifies the control plane that the data flow has started. + pub async fn notify_started( + &self, + ctx: &str, + flow_id: &str, + data_address: Option, + ) -> SdkResult<()> { + self.send_callback( + ctx, + flow_id, + "started", + data_address.clone(), + None, + |flow| { + flow.data_address = data_address; + flow.transition_to_started() + }, + ) + .await + } + + /// Notifies the control plane that the data flow has completed. + pub async fn notify_completed(&self, ctx: &str, flow_id: &str) -> SdkResult<()> { + self.send_callback(ctx, flow_id, "completed", None, None, |flow| { + flow.transition_to_completed() + }) + .await + } + + /// Notifies the control plane that the data flow has errored. The optional + /// `error` is forwarded as the `error` field of the status message. + pub async fn notify_errored( + &self, + ctx: &str, + flow_id: &str, + error: Option, + ) -> SdkResult<()> { + self.send_callback(ctx, flow_id, "errored", None, error.clone(), move |flow| { + flow.transition_to_terminated(error) + }) + .await + } + + async fn send_callback( + &self, + _ctx: &str, + flow_id: &str, + operation: &str, + data_address: Option, + error: Option, + op: CB, + ) -> SdkResult<()> + where + CB: FnOnce(&mut DataFlow) -> Result<(), TransitionError>, + { + let mut tx = self.ctx.begin().await?; + + let mut flow = self + .fetch_by_id(&mut tx, flow_id) + .await? + .ok_or_else(|| DbError::NotFound(flow_id.to_string()))?; + + op(&mut flow)?; + + let msg = DataFlowStatusMessage::builder() + .data_flow_id(flow.id.clone()) + .maybe_data_address(data_address) + .state(flow.state.clone()) + .maybe_error(error) + .build(); + + let url = format!( + "{}/transfers/{}/dataflow/{}", + flow.callback_address.trim_end_matches('/'), + flow.id, + operation + ); + + let resp = self.client.post(&url).json(&msg).send().await?; + if !resp.status().is_success() { + let status = resp.status().as_u16(); + let body = resp.text().await.unwrap_or_default(); + return Err(SdkError::NotificationStatus { status, body }); + } + + self.repo.update(&mut tx, &flow).await?; + tx.commit().await?; + + Ok(()) + } + pub async fn fetch_by_id( &self, tx: &mut C::Transaction, diff --git a/crates/sdk/src/sdk_test.rs b/crates/sdk/src/sdk_test.rs index 1ca908b..38c51eb 100644 --- a/crates/sdk/src/sdk_test.rs +++ b/crates/sdk/src/sdk_test.rs @@ -841,6 +841,181 @@ mod status { } } +mod notify { + use std::future; + + use wiremock::{ + Mock, MockServer, ResponseTemplate, + matchers::{body_string_contains, method, path}, + }; + + use crate::{ + core::{ + db::tx::{MockTransaction, MockTransactionalContext}, + error::DbError, + model::data_flow::{DataFlow, DataFlowState}, + }, + error::SdkError, + sdk::DataPlaneSdk, + sdk_test::{context, flow}, + }; + + fn flow_with(callback: &str, state: DataFlowState) -> DataFlow { + let mut f = flow(); + f.callback_address = callback.to_string(); + f.state = state; + f + } + + /// Configures the mock context so that `begin`/`commit` succeed and + /// `fetch_by_id` returns the supplied flow. + fn with_flow(f: DataFlow) -> DataPlaneSdk { + let (mut ctx, mut repo, handler) = context(); + + ctx.expect_begin().returning(|| { + let mut tx = MockTransaction::new(); + tx.expect_commit() + .returning(|| Box::pin(future::ready(Ok(())))); + Box::pin(future::ready(Ok(tx))) + }); + + repo.expect_fetch_by_id() + .returning(move |_, _| Box::pin(future::ready(Ok(Some(f.clone()))))); + + repo.expect_update() + .returning(|_, _| Box::pin(future::ready(Ok(())))); + + DataPlaneSdk::builder(ctx) + .with_repo(repo) + .with_handler(handler) + .build() + .unwrap() + } + + #[tokio::test] + async fn notify_started_posts_to_callback() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/transfers/flow-id/dataflow/started")) + .and(body_string_contains("\"state\":\"STARTED\"")) + .and(body_string_contains("\"dataFlowId\":\"flow-id\"")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + + let sdk = with_flow(flow_with(&server.uri(), DataFlowState::Started)); + + let response = sdk.notify_started("participant", "flow-id", None).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn notify_prepared_posts_to_callback() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/transfers/flow-id/dataflow/prepared")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + + let sdk = with_flow(flow_with(&server.uri(), DataFlowState::Prepared)); + + assert!( + sdk.notify_prepared("participant", "flow-id", None) + .await + .is_ok() + ); + } + + #[tokio::test] + async fn notify_completed_posts_to_callback() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/transfers/flow-id/dataflow/completed")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + + let sdk = with_flow(flow_with(&server.uri(), DataFlowState::Completed)); + + assert!(sdk.notify_completed("participant", "flow-id").await.is_ok()); + } + + #[tokio::test] + async fn notify_errored_forwards_error() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/transfers/flow-id/dataflow/errored")) + .and(body_string_contains("something broke")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + + let sdk = with_flow(flow_with(&server.uri(), DataFlowState::Started)); + + let response = sdk + .notify_errored( + "participant", + "flow-id", + Some("something broke".to_string()), + ) + .await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn notify_non_success_status_is_error() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .respond_with(ResponseTemplate::new(500).set_body_string("boom")) + .mount(&server) + .await; + + let sdk = with_flow(flow_with(&server.uri(), DataFlowState::Started)); + + let response = sdk.notify_started("participant", "flow-id", None).await; + + assert!(matches!( + response, + Err(SdkError::NotificationStatus { status: 500, .. }) + )); + } + + #[tokio::test] + async fn notify_not_found() { + let (mut ctx, mut repo, handler) = context(); + + ctx.expect_begin().returning(|| { + let mut tx = MockTransaction::new(); + tx.expect_commit() + .returning(|| Box::pin(future::ready(Ok(())))); + Box::pin(future::ready(Ok(tx))) + }); + + repo.expect_fetch_by_id() + .returning(|_, _| Box::pin(future::ready(Ok(None)))); + + let sdk = DataPlaneSdk::builder(ctx) + .with_repo(repo) + .with_handler(handler) + .build() + .unwrap(); + + let response = sdk.notify_started("participant", "flow-id", None).await; + + assert!(matches!( + response, + Err(SdkError::Repo(DbError::NotFound(_))) + )); + } +} + fn start_message() -> DataFlowStartMessage { DataFlowStartMessage::builder() .process_id("process-id")