diff --git a/Cargo.lock b/Cargo.lock index 209e53361..876422715 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2852,6 +2852,7 @@ dependencies = [ "ironrdp-dvc", "ironrdp-echo", "ironrdp-egfx", + "ironrdp-error", "ironrdp-graphics", "ironrdp-nscodec", "ironrdp-pdu", diff --git a/crates/ironrdp-server/Cargo.toml b/crates/ironrdp-server/Cargo.toml index c01b60668..1be7d6105 100644 --- a/crates/ironrdp-server/Cargo.toml +++ b/crates/ironrdp-server/Cargo.toml @@ -39,8 +39,9 @@ tokio-rustls = "0.26" # public async-trait = "0.1" ironrdp-async = { path = "../ironrdp-async", version = "0.9" } ironrdp-ainput = { path = "../ironrdp-ainput", version = "0.7" } -ironrdp-core = { path = "../ironrdp-core", version = "0.2" } +ironrdp-core = { path = "../ironrdp-core", version = "0.2" } # public ironrdp-egfx = { path = "../ironrdp-egfx", version = "0.2", optional = true } +ironrdp-error = { path = "../ironrdp-error", version = "0.2", features = ["std"] } # public ironrdp-nscodec = { path = "../ironrdp-nscodec", version = "0.1", optional = true, features = ["encoder"] } ironrdp-pdu = { path = "../ironrdp-pdu", version = "0.8" } # public ironrdp-svc = { path = "../ironrdp-svc", version = "0.7" } # public diff --git a/crates/ironrdp-server/src/echo.rs b/crates/ironrdp-server/src/echo.rs index 8caf69b3d..513b5a590 100644 --- a/crates/ironrdp-server/src/echo.rs +++ b/crates/ironrdp-server/src/echo.rs @@ -3,13 +3,13 @@ use std::collections::{BTreeMap, VecDeque}; use std::sync::{Arc, Mutex, MutexGuard}; use std::time::Instant; -use anyhow::{Context as _, Result, bail}; use ironrdp_core::impl_as_any; use ironrdp_dvc::{DvcMessage, DvcProcessor, DvcServerProcessor}; use ironrdp_echo::server::EchoServer; use ironrdp_pdu::PduResult; use tokio::sync::mpsc; +use crate::error::{ServerError, ServerErrorExt as _, ServerResult}; use crate::server::ServerEvent; #[derive(Debug, Clone)] @@ -55,14 +55,14 @@ impl EchoServerHandle { /// Sends a runtime ECHO request. /// /// The payload must be at least one byte, as required by MS-RDPEECO section 3.1.5.1. - pub fn send_request(&self, payload: Vec) -> Result<()> { + pub fn send_request(&self, payload: Vec) -> ServerResult<()> { if payload.is_empty() { - bail!("echoRequest payload must be at least one byte"); + return Err(ServerError::reason("echo request", "payload must be at least one byte")); } self.sender .send(ServerEvent::Echo(EchoServerMessage::SendRequest { payload })) - .map_err(|_error| anyhow::anyhow!("send ECHO request event")) + .map_err(|_error| ServerError::channel("send ECHO request event")) } /// Drains collected RTT measurements. @@ -151,6 +151,6 @@ impl DvcProcessor for EchoDvcBridge { impl DvcServerProcessor for EchoDvcBridge {} -pub(crate) fn build_echo_request(payload: Vec) -> Result { - EchoServer::request_message(payload).context("build ECHO request message") +pub(crate) fn build_echo_request(payload: Vec) -> ServerResult { + EchoServer::request_message(payload).map_err(|e| ServerError::custom("build ECHO request message", e)) } diff --git a/crates/ironrdp-server/src/encoder/mod.rs b/crates/ironrdp-server/src/encoder/mod.rs index d3f83c9e9..e4f8a7212 100644 --- a/crates/ironrdp-server/src/encoder/mod.rs +++ b/crates/ironrdp-server/src/encoder/mod.rs @@ -1,7 +1,7 @@ use core::fmt; use core::num::NonZeroU16; -use anyhow::{Context as _, Result, anyhow}; +use crate::error::{ServerError, ServerErrorExt as _, ServerResult, ServerResultExt as _}; use ironrdp_acceptor::DesktopSize; use ironrdp_graphics::diff::{Rect, find_different_rects_sub}; use ironrdp_pdu::encode_vec; @@ -128,12 +128,12 @@ impl UpdateEncoder { surface_flags: CmdFlags, codecs: UpdateEncoderCodecs, max_request_size: u32, - ) -> Result { + ) -> ServerResult { let bitmap_updater = if surface_flags.contains(CmdFlags::SET_SURFACE_BITS) { match codecs { #[cfg(feature = "qoiz")] UpdateEncoderCodecs { qoiz: Some(id), .. } => { - BitmapUpdater::Qoiz(QoizHandler::new(id).context("failed to initialize qoiz handler")?) + BitmapUpdater::Qoiz(QoizHandler::new(id).with_context("failed to initialize qoiz handler")?) } #[cfg(feature = "qoi")] UpdateEncoderCodecs { qoi: Some(id), .. } => BitmapUpdater::Qoi(QoiHandler::new(id)), @@ -162,7 +162,8 @@ impl UpdateEncoder { desktop_size, framebuffer: None, bitmap_updater: Some(bitmap_updater), - max_request_size: usize::try_from(max_request_size).context("max_request_size")?, + max_request_size: usize::try_from(max_request_size) + .map_err(|e| ServerError::custom("max_request_size", e))?, }) } @@ -182,7 +183,7 @@ impl UpdateEncoder { .set_desktop_size(size); } - fn rgba_pointer(ptr: RGBAPointer) -> Result { + fn rgba_pointer(ptr: RGBAPointer) -> ServerResult { let xor_mask = ptr.data; let hot_spot = Point16 { @@ -201,10 +202,13 @@ impl UpdateEncoder { xor_bpp: 32, color_pointer, }; - Ok(UpdateFragmenter::new(UpdateCode::NewPointer, encode_vec(&ptr)?)) + Ok(UpdateFragmenter::new( + UpdateCode::NewPointer, + encode_vec(&ptr).map_err(ServerError::encode)?, + )) } - fn color_pointer(ptr: ColorPointer) -> Result { + fn color_pointer(ptr: ColorPointer) -> ServerResult { let hot_spot = Point16 { x: ptr.hot_x, y: ptr.hot_y, @@ -217,24 +221,33 @@ impl UpdateEncoder { xor_mask: &ptr.xor_mask, and_mask: &ptr.and_mask, }; - Ok(UpdateFragmenter::new(UpdateCode::ColorPointer, encode_vec(&ptr)?)) + Ok(UpdateFragmenter::new( + UpdateCode::ColorPointer, + encode_vec(&ptr).map_err(ServerError::encode)?, + )) } - fn cached_pointer(cache_index: u16) -> Result { + fn cached_pointer(cache_index: u16) -> ServerResult { let ptr = CachedPointerAttribute { cache_index }; - Ok(UpdateFragmenter::new(UpdateCode::CachedPointer, encode_vec(&ptr)?)) + Ok(UpdateFragmenter::new( + UpdateCode::CachedPointer, + encode_vec(&ptr).map_err(ServerError::encode)?, + )) } - fn default_pointer() -> Result { + fn default_pointer() -> ServerResult { Ok(UpdateFragmenter::new(UpdateCode::DefaultPointer, vec![])) } - fn hide_pointer() -> Result { + fn hide_pointer() -> ServerResult { Ok(UpdateFragmenter::new(UpdateCode::HiddenPointer, vec![])) } - fn pointer_position(pos: PointerPositionAttribute) -> Result { - Ok(UpdateFragmenter::new(UpdateCode::PositionPointer, encode_vec(&pos)?)) + fn pointer_position(pos: PointerPositionAttribute) -> ServerResult { + Ok(UpdateFragmenter::new( + UpdateCode::PositionPointer, + encode_vec(&pos).map_err(ServerError::encode)?, + )) } fn bitmap_diffs(&mut self, bitmap: &BitmapUpdate) -> Vec { @@ -330,7 +343,7 @@ impl UpdateEncoder { } } - async fn bitmap(&mut self, bitmap: BitmapUpdate) -> Result { + async fn bitmap(&mut self, bitmap: BitmapUpdate) -> ServerResult { // Move the bitmap updater to satisfy spawn_blocking 'static requirement. // It is restored after the blocking operation completes. let mut updater = self.bitmap_updater.take().expect("bitmap updater always Some"); @@ -339,7 +352,8 @@ impl UpdateEncoder { let result = time_warn!("Encoding bitmap", 10, updater.handle(&bitmap)); (result, updater) }) - .await?; + .await + .map_err(|e| ServerError::custom("bitmap encoder task", e))?; self.bitmap_updater = Some(updater); @@ -367,7 +381,7 @@ pub(crate) struct EncoderIter<'a> { impl EncoderIter<'_> { #[cfg_attr(feature = "__bench", visibility::make(pub))] - pub(crate) async fn next(&mut self) -> Option> { + pub(crate) async fn next(&mut self) -> Option> { loop { let state = core::mem::take(&mut self.state); let encoder = &mut self.encoder; @@ -406,25 +420,55 @@ impl EncoderIter<'_> { let x = match u16::try_from(x) { Ok(x) => x, - Err(_) => return Some(Err(anyhow!("invalid `x`: out of range integral conversion"))), + Err(_) => { + return Some(Err(ServerError::reason( + "bitmap diff", + "invalid `x`: out of range integral conversion", + ))); + } }; let y = match u16::try_from(y) { Ok(y) => y, - Err(_) => return Some(Err(anyhow!("invalid `y`: out of range integral conversion"))), + Err(_) => { + return Some(Err(ServerError::reason( + "bitmap diff", + "invalid `y`: out of range integral conversion", + ))); + } }; let width = match u16::try_from(width) { Ok(width) => match NonZeroU16::new(width) { Some(width) => width, - None => return Some(Err(anyhow!("rectangle width cannot be zero"))), + None => { + return Some(Err(ServerError::reason( + "bitmap diff", + "rectangle width cannot be zero", + ))); + } }, - Err(_) => return Some(Err(anyhow!("invalid `width`: out of range integral conversion"))), + Err(_) => { + return Some(Err(ServerError::reason( + "bitmap diff", + "invalid `width`: out of range integral conversion", + ))); + } }; let height = match u16::try_from(height) { Ok(height) => match NonZeroU16::new(height) { Some(height) => height, - None => return Some(Err(anyhow!("rectangle height cannot be zero"))), + None => { + return Some(Err(ServerError::reason( + "bitmap diff", + "rectangle height cannot be zero", + ))); + } }, - Err(_) => return Some(Err(anyhow!("invalid `height`: out of range integral conversion"))), + Err(_) => { + return Some(Err(ServerError::reason( + "bitmap diff", + "invalid `height`: out of range integral conversion", + ))); + } }; let Some(sub) = bitmap.sub(x, y, width, height) else { @@ -460,7 +504,7 @@ enum BitmapUpdater { } impl BitmapUpdater { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result { + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult { match self { Self::None(up) => up.handle(bitmap), Self::Bitmap(up) => up.handle(bitmap), @@ -482,14 +526,14 @@ impl BitmapUpdater { } trait BitmapUpdateHandler { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result; + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult; } #[derive(Clone, Debug)] struct NoneHandler; impl BitmapUpdateHandler for NoneHandler { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result { + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult { let stride = usize::from(bitmap.format.bytes_per_pixel()) * usize::from(bitmap.width.get()); let mut data = Vec::with_capacity(stride * usize::from(bitmap.height.get())); for row in bitmap.data.chunks(bitmap.stride.get()).rev() { @@ -519,7 +563,7 @@ impl BitmapHandler { } impl BitmapUpdateHandler for BitmapHandler { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result { + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult { let mut buffer = vec![0; bitmap.data.len() * 2]; // TODO: estimate bitmap encoded size let len = loop { match self.bitmap.encode(bitmap, buffer.as_mut_slice()) { @@ -529,9 +573,9 @@ impl BitmapUpdateHandler for BitmapHandler { buffer.resize(buffer.len() * 2, 0); debug!("encoder buffer resized to: {}", buffer.len() * 2); } - _ => Err(e).context("bitmap encode error")?, + _ => return Err(ServerError::encode(e)), }, - BitmapEncodeError::Rle(e) => Err(e).context("bitmap RLE encode error")?, + BitmapEncodeError::Rle(e) => return Err(ServerError::custom("bitmap RLE encode error", e)), }, Ok(len) => break len, } @@ -564,7 +608,7 @@ impl RemoteFxHandler { } impl BitmapUpdateHandler for RemoteFxHandler { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result { + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult { let mut buffer = vec![0; bitmap.data.len()]; let len = loop { match self @@ -576,7 +620,7 @@ impl BitmapUpdateHandler for RemoteFxHandler { buffer.resize(buffer.len() * 2, 0); debug!("encoder buffer resized to: {}", buffer.len() * 2); } - _ => Err(e).context("RemoteFX encode error")?, + _ => return Err(ServerError::encode(e)), }, Ok(len) => break len, } @@ -601,7 +645,7 @@ impl QoiHandler { #[cfg(feature = "qoi")] impl BitmapUpdateHandler for QoiHandler { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result { + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult { let data = qoi_encode(bitmap)?; set_surface(bitmap, self.codec_id, &data) } @@ -622,23 +666,29 @@ impl fmt::Debug for QoizHandler { #[cfg(feature = "qoiz")] impl QoizHandler { - fn new(codec_id: u8) -> Result { + fn new(codec_id: u8) -> ServerResult { let mut zctxt = zstd_safe::CCtx::default(); zctxt .set_parameter(zstd_safe::CParameter::CompressionLevel(3)) .map_err(|code| { - anyhow!( - "failed to set zstd compression level: {}", - zstd_safe::get_error_name(code) + ServerError::reason( + "qoiz init", + format!( + "failed to set zstd compression level: {}", + zstd_safe::get_error_name(code) + ), ) })?; zctxt .set_parameter(zstd_safe::CParameter::EnableLongDistanceMatching(true)) .map_err(|code| { - anyhow!( - "failed to set zstd enable long distance matching: {}", - zstd_safe::get_error_name(code) + ServerError::reason( + "qoiz init", + format!( + "failed to set zstd enable long distance matching: {}", + zstd_safe::get_error_name(code) + ), ) })?; @@ -648,7 +698,7 @@ impl QoizHandler { #[cfg(feature = "qoiz")] impl BitmapUpdateHandler for QoizHandler { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result { + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult { let qoi = qoi_encode(bitmap)?; let mut inb = zstd_safe::InBuffer::around(&qoi); let mut data = vec![0; qoi.len()]; @@ -664,7 +714,12 @@ impl BitmapUpdateHandler for QoizHandler { &mut inb, zstd_safe::zstd_sys::ZSTD_EndDirective::ZSTD_e_flush, ) - .map_err(|code| anyhow!("failed to Zstd compress: {}", zstd_safe::get_error_name(code)))?; + .map_err(|code| { + ServerError::reason( + "qoiz compress", + format!("failed to Zstd compress: {}", zstd_safe::get_error_name(code)), + ) + })?; if res == 0 { break; } @@ -695,7 +750,7 @@ impl NsCodecHandler { #[cfg(feature = "nscodec")] impl BitmapUpdateHandler for NsCodecHandler { - fn handle(&mut self, bitmap: &BitmapUpdate) -> Result { + fn handle(&mut self, bitmap: &BitmapUpdate) -> ServerResult { let data = ironrdp_nscodec::encoder::encode( &bitmap.data, bitmap.width.get(), @@ -709,7 +764,7 @@ impl BitmapUpdateHandler for NsCodecHandler { } #[cfg(feature = "qoi")] -fn qoi_encode(bitmap: &BitmapUpdate) -> Result> { +fn qoi_encode(bitmap: &BitmapUpdate) -> ServerResult> { use ironrdp_graphics::image_processing::PixelFormat::*; // Map every 4-byte input — whether it nominally has an alpha byte or // an "X" filler — to the 3-channel-output `*x` variant of @@ -735,11 +790,12 @@ fn qoi_encode(bitmap: &BitmapUpdate) -> Result> { let enc = qoi::EncoderBuilder::new(&bitmap.data, bitmap.width.get().into(), bitmap.height.get().into()) .stride(bitmap.stride.get()) .raw_channels(raw_channels) - .build()?; - Ok(enc.encode_to_vec()?) + .build() + .map_err(|e| ServerError::custom("qoi encoder build", e))?; + enc.encode_to_vec().map_err(|e| ServerError::custom("qoi encode", e)) } -fn set_surface(bitmap: &BitmapUpdate, codec_id: u8, data: &[u8]) -> Result { +fn set_surface(bitmap: &BitmapUpdate, codec_id: u8, data: &[u8]) -> ServerResult { let destination = ExclusiveRectangle { left: bitmap.x, top: bitmap.y, @@ -759,5 +815,8 @@ fn set_surface(bitmap: &BitmapUpdate, codec_id: u8, data: &[u8]) -> Result) -> fmt::Result { + match self { + Self::Encode(_) => write!(f, "encode error"), + Self::Decode(_) => write!(f, "decode error"), + Self::Io(_) => write!(f, "I/O error"), + Self::Channel => write!(f, "channel error"), + Self::Unsupported => write!(f, "unsupported"), + Self::Reason(reason) => write!(f, "reason: {reason}"), + Self::General => write!(f, "general error"), + Self::Custom => write!(f, "custom error"), + } + } +} + +impl core::error::Error for ServerErrorKind { + fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { + match self { + Self::Encode(e) => Some(e), + Self::Decode(e) => Some(e), + Self::Io(e) => Some(e), + Self::Channel | Self::Unsupported | Self::Reason(_) | Self::General | Self::Custom => None, + } + } +} + +/// Server-side failure type. +/// +/// A typed alias of [`ironrdp_error::Error`] specialized to +/// [`ServerErrorKind`]. The wrapper adds a static `&'static str` context and +/// an optional opaque `source` to whichever kind of failure occurred. +pub type ServerError = ironrdp_error::Error; + +/// Convenience alias for `Result`. +pub type ServerResult = Result; + +/// Constructors for [`ServerError`] that match the shape of +/// [`ironrdp_connector::ConnectorErrorExt`]. +pub trait ServerErrorExt { + /// Build a [`ServerErrorKind::Encode`] error from an [`EncodeError`]. + fn encode(error: EncodeError) -> Self; + /// Build a [`ServerErrorKind::Decode`] error from a [`DecodeError`]. + fn decode(error: DecodeError) -> Self; + /// Build a [`ServerErrorKind::Io`] error with a static context and an + /// [`io::Error`] source. + fn io(context: &'static str, error: io::Error) -> Self; + /// Build a [`ServerErrorKind::Channel`] error with the channel name or + /// failure description carried in the context. + fn channel(context: &'static str) -> Self; + /// Build a [`ServerErrorKind::Unsupported`] error with the unsupported + /// feature named in the context. + fn unsupported(context: &'static str) -> Self; + /// Build a [`ServerErrorKind::General`] error with a static context. + fn general(context: &'static str) -> Self; + /// Build a [`ServerErrorKind::Reason`] error with a static context and a + /// runtime description. + fn reason(context: &'static str, reason: impl Into) -> Self; + /// Build a [`ServerErrorKind::Custom`] error with a static context and an + /// arbitrary source. + fn custom(context: &'static str, error: E) -> Self + where + E: core::error::Error + Sync + Send + 'static; +} + +impl ServerErrorExt for ServerError { + fn encode(error: EncodeError) -> Self { + Self::new("encode error", ServerErrorKind::Encode(error)) + } + + fn decode(error: DecodeError) -> Self { + Self::new("decode error", ServerErrorKind::Decode(error)) + } + + fn io(context: &'static str, error: io::Error) -> Self { + Self::new(context, ServerErrorKind::Io(error)) + } + + fn channel(context: &'static str) -> Self { + Self::new(context, ServerErrorKind::Channel) + } + + fn unsupported(context: &'static str) -> Self { + Self::new(context, ServerErrorKind::Unsupported) + } + + fn general(context: &'static str) -> Self { + Self::new(context, ServerErrorKind::General) + } + + fn reason(context: &'static str, reason: impl Into) -> Self { + Self::new(context, ServerErrorKind::Reason(reason.into())) + } + + fn custom(context: &'static str, error: E) -> Self + where + E: core::error::Error + Sync + Send + 'static, + { + Self::new(context, ServerErrorKind::Custom).with_source(error) + } +} + +/// Result-side helpers mirroring [`ironrdp_connector::ConnectorResultExt`]. +pub trait ServerResultExt { + /// Replace the `&'static str` context on any error in `Self`. + #[must_use] + fn with_context(self, context: &'static str) -> Self; + /// Attach a source to any error in `Self`. + #[must_use] + fn with_source(self, source: E) -> Self + where + E: core::error::Error + Sync + Send + 'static; +} + +impl ServerResultExt for ServerResult { + fn with_context(self, context: &'static str) -> Self { + self.map_err(|mut e| { + e.set_context(context); + e + }) + } + + fn with_source(self, source: E) -> Self + where + E: core::error::Error + Sync + Send + 'static, + { + self.map_err(|e| e.with_source(source)) + } +} + +/// Bridges anyhow errors at the public API boundary while the migration to +/// typed errors is staged. +/// +/// Internal call sites still use `anyhow::Result`; conversion happens here so +/// the public signatures can advertise [`ServerResult`] today without forcing +/// every internal site to convert in this PR. PR #2 in the staged migration +/// (see [#1209]) removes the remaining `anyhow` usage and this helper. +/// +/// [#1209]: https://github.com/Devolutions/IronRDP/issues/1209 +pub(crate) fn from_anyhow(error: anyhow::Error) -> ServerError { + ServerError::new("server error", ServerErrorKind::Custom).with_source(AnyhowError(error)) +} + +/// Newtype wrapper that gives [`anyhow::Error`] a `core::error::Error` impl +/// suitable for `ironrdp_error::Error::with_source`. +#[derive(Debug)] +struct AnyhowError(anyhow::Error); + +impl fmt::Display for AnyhowError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:#}", self.0) + } +} + +impl core::error::Error for AnyhowError { + /// Forwards to the wrapped [`anyhow::Error`]'s cause chain so callers + /// traversing [`core::error::Error::source`] reach the original root + /// cause rather than stopping at this newtype. + fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { + self.0.source() + } +} diff --git a/crates/ironrdp-server/src/helper.rs b/crates/ironrdp-server/src/helper.rs index 6d4055715..e04a7ccfd 100644 --- a/crates/ironrdp-server/src/helper.rs +++ b/crates/ironrdp-server/src/helper.rs @@ -3,12 +3,13 @@ use std::io::BufReader; use std::path::Path; use std::sync::Arc; -use anyhow::Context as _; use rustls_pemfile::{certs, pkcs8_private_keys}; use tokio_rustls::rustls::pki_types::pem::PemObject as _; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio_rustls::{TlsAcceptor, rustls}; +use crate::error::{ServerError, ServerErrorExt as _, ServerResult}; + pub struct TlsIdentityCtx { pub certs: Vec>, pub priv_key: PrivateKeyDer<'static>, @@ -19,39 +20,45 @@ impl TlsIdentityCtx { /// A constructor to create a `TlsIdentityCtx` from the given certificate and key paths. /// /// The file format can be either PEM (if the file extension ends with .pem) or DER. - pub fn init_from_paths(cert_path: &Path, key_path: &Path) -> anyhow::Result { + pub fn init_from_paths(cert_path: &Path, key_path: &Path) -> ServerResult { let certs = if cert_path.extension().is_some_and(|ext| ext == "pem") { CertificateDer::pem_file_iter(cert_path) - .with_context(|| format!("reading server cert `{cert_path:?}`"))? + .map_err(|e| ServerError::custom("reading server cert", e))? .collect::, _>>() - .with_context(|| format!("collecting server cert `{cert_path:?}`"))? + .map_err(|e| ServerError::custom("collecting server cert", e))? } else { certs(&mut BufReader::new( - File::open(cert_path).with_context(|| format!("opening server cert `{cert_path:?}`"))?, + File::open(cert_path).map_err(|e| ServerError::io("opening server cert", e))?, )) .collect::, _>>() - .with_context(|| format!("collecting server cert `{cert_path:?}`"))? + .map_err(|e| ServerError::io("collecting server cert", e))? }; let priv_key = if key_path.extension().is_some_and(|ext| ext == "pem") { - PrivateKeyDer::from_pem_file(key_path).with_context(|| format!("reading server key `{key_path:?}`"))? + PrivateKeyDer::from_pem_file(key_path).map_err(|e| ServerError::custom("reading server key", e))? } else { - pkcs8_private_keys(&mut BufReader::new(File::open(key_path)?)) - .next() - .context("no private key")? - .map(PrivateKeyDer::from)? + pkcs8_private_keys(&mut BufReader::new( + File::open(key_path).map_err(|e| ServerError::io("opening server key", e))?, + )) + .next() + .ok_or_else(|| ServerError::reason("server key", "no private key"))? + .map(PrivateKeyDer::from) + .map_err(|e| ServerError::io("reading server key", e))? }; let pub_key = { use x509_cert::der::Decode as _; - let cert = certs.first().ok_or_else(|| std::io::Error::other("invalid cert"))?; - let cert = x509_cert::Certificate::from_der(cert).map_err(std::io::Error::other)?; + let cert = certs + .first() + .ok_or_else(|| ServerError::reason("server cert", "invalid cert"))?; + let cert = + x509_cert::Certificate::from_der(cert).map_err(|e| ServerError::custom("parsing server cert", e))?; cert.tbs_certificate .subject_public_key_info .subject_public_key .as_bytes() - .ok_or_else(|| std::io::Error::other("subject public key BIT STRING is not aligned"))? + .ok_or_else(|| ServerError::reason("server cert", "subject public key BIT STRING is not aligned"))? .to_owned() }; @@ -62,11 +69,11 @@ impl TlsIdentityCtx { }) } - pub fn make_acceptor(&self) -> anyhow::Result { + pub fn make_acceptor(&self) -> ServerResult { let mut server_config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(self.certs.clone(), self.priv_key.clone_key()) - .context("bad certificate/key")?; + .map_err(|e| ServerError::custom("bad certificate/key", e))?; // This adds support for the SSLKEYLOGFILE env variable (https://wiki.wireshark.org/TLS#using-the-pre-master-secret) server_config.key_log = Arc::new(rustls::KeyLogFile::new()); diff --git a/crates/ironrdp-server/src/lib.rs b/crates/ironrdp-server/src/lib.rs index 4ae6b7167..7f422fd61 100644 --- a/crates/ironrdp-server/src/lib.rs +++ b/crates/ironrdp-server/src/lib.rs @@ -13,6 +13,7 @@ mod clipboard; mod display; mod echo; mod encoder; +mod error; #[cfg(feature = "egfx")] mod gfx; mod handler; @@ -27,6 +28,7 @@ pub use display::{ RdpServerDisplayUpdates, }; pub use echo::{EchoDvcBridge, EchoRoundTripMeasurement, EchoServerHandle, EchoServerMessage}; +pub use error::{ServerError, ServerErrorExt, ServerErrorKind, ServerResult, ServerResultExt}; #[cfg(feature = "egfx")] pub use gfx::{EgfxServerMessage, GfxDvcBridge, GfxServerFactory, GfxServerHandle}; pub use handler::{KeyboardEvent, MouseEvent, RdpServerInputHandler}; diff --git a/crates/ironrdp-server/src/server.rs b/crates/ironrdp-server/src/server.rs index 0be2a0707..163689b65 100644 --- a/crates/ironrdp-server/src/server.rs +++ b/crates/ironrdp-server/src/server.rs @@ -5,7 +5,6 @@ use core::time::Duration; use std::rc::Rc; use std::sync::Arc; -use anyhow::{Context as _, Result, bail}; use ironrdp_acceptor::{Acceptor, AcceptorResult, BeginResult, DesktopSize}; use ironrdp_async::Framed; use ironrdp_cliprdr::CliprdrServer; @@ -39,6 +38,7 @@ use crate::clipboard::CliprdrServerFactory; use crate::display::{DisplayUpdate, RdpServerDisplay}; use crate::echo::{EchoDvcBridge, EchoServerHandle, EchoServerMessage, build_echo_request}; use crate::encoder::{UpdateEncoder, UpdateEncoderCodecs}; +use crate::error::{ServerError, ServerErrorExt as _, ServerResult, ServerResultExt as _, from_anyhow}; #[cfg(feature = "egfx")] use crate::gfx::{EgfxServerMessage, GfxServerFactory}; use crate::handler::RdpServerInputHandler; @@ -84,7 +84,7 @@ pub trait ConnectionHandler: Send { &mut self, peer: SocketAddr, duration: Duration, - error: Option<&anyhow::Error>, + error: Option<&ServerError>, ) -> PostConnectionAction { let _ = (peer, duration, error); PostConnectionAction::Continue @@ -671,7 +671,7 @@ impl RdpServer { acceptor.attach_static_channel(dvc); } - pub async fn run_connection(&mut self, stream: S) -> Result<()> + pub async fn run_connection(&mut self, stream: S) -> ServerResult<()> where S: AsyncRead + AsyncWrite + Send + Sync + Unpin, { @@ -697,7 +697,7 @@ impl RdpServer { let res = ironrdp_acceptor::accept_begin(framed, &mut acceptor) .await - .context("accept_begin failed")?; + .map_err(|e| ServerError::custom("accept_begin failed", e))?; match res { BeginResult::ShouldUpgrade(stream) => { @@ -731,7 +731,8 @@ impl RdpServer { pub_key.clone(), None, ) - .await?; + .await + .map_err(|e| ServerError::custom("accept_credssp", e))?; } let framed = self.accept_finalize(framed, acceptor).await?; @@ -750,12 +751,12 @@ impl RdpServer { Ok(()) } - pub async fn run(&mut self) -> Result<()> { + pub async fn run(&mut self) -> ServerResult<()> { // Create socket with control over options before binding. // Using TcpSocket instead of TcpListener::bind() allows setting // SO_REUSEADDR and IPv6 dual-stack mode. let socket = match self.opts.addr { - SocketAddr::V4(_) => TcpSocket::new_v4().context("create IPv4 socket")?, + SocketAddr::V4(_) => TcpSocket::new_v4().map_err(|e| ServerError::io("create IPv4 socket", e))?, SocketAddr::V6(_) => { // IPv6 socket: on Linux, dual-stack is the default // (net.ipv6.bindv6only=0), so IPv4 clients connect as @@ -763,7 +764,7 @@ impl RdpServer { // where IPV6_V6ONLY defaults to 1 (Windows, some BSDs), // only IPv6 clients will be accepted and a separate IPv4 // listener would be needed. - TcpSocket::new_v6().context("create IPv6 socket")? + TcpSocket::new_v6().map_err(|e| ServerError::io("create IPv6 socket", e))? } }; @@ -772,12 +773,18 @@ impl RdpServer { // on Windows SO_REUSEADDR has different semantics that allow a // second process to bind the same port, which is a security risk. #[cfg(unix)] - socket.set_reuseaddr(true).context("set SO_REUSEADDR")?; + socket + .set_reuseaddr(true) + .map_err(|e| ServerError::io("set SO_REUSEADDR", e))?; - socket.bind(self.opts.addr).context("bind listen address")?; + socket + .bind(self.opts.addr) + .map_err(|e| ServerError::io("bind listen address", e))?; - let listener = socket.listen(LISTENER_BACKLOG).context("start listener")?; - let local_addr = listener.local_addr()?; + let listener = socket + .listen(LISTENER_BACKLOG) + .map_err(|e| ServerError::io("start listener", e))?; + let local_addr = listener.local_addr().map_err(|e| ServerError::io("local_addr", e))?; debug!("Listening for connections on {local_addr}"); self.local_addr = Some(local_addr); @@ -862,10 +869,10 @@ impl RdpServer { writer: &mut impl FramedWrite, io_channel_id: u16, user_channel_id: u16, - ) -> Result { + ) -> ServerResult { match action { Action::FastPath => { - let input = decode(&bytes)?; + let input = decode(&bytes).map_err(ServerError::decode)?; self.handle_fastpath(input).await; } @@ -873,7 +880,7 @@ impl RdpServer { if self .handle_x224(writer, io_channel_id, user_channel_id, &bytes) .await - .context("X224 input error")? + .map_err(|e| ServerError::custom("X224 input error", e))? { debug!("Got disconnect request"); return Ok(RunState::Disconnect); @@ -891,7 +898,7 @@ impl RdpServer { io_channel_id: u16, buffer: &mut Vec, mut encoder: UpdateEncoder, - ) -> Result<(RunState, UpdateEncoder)> { + ) -> ServerResult<(RunState, UpdateEncoder)> { if let DisplayUpdate::Resize(desktop_size) = update { debug!(?desktop_size, "Display resize"); encoder.set_desktop_size(desktop_size); @@ -905,7 +912,7 @@ impl RdpServer { break; }; - let mut fragmenter = fragmenter.context("error while encoding")?; + let mut fragmenter = fragmenter?; if fragmenter.size_hint() > buffer.len() { buffer.resize(fragmenter.size_hint(), 0); } @@ -914,7 +921,7 @@ impl RdpServer { writer .write_all(&buffer[..len]) .await - .context("failed to write display update")?; + .map_err(|e| ServerError::custom("failed to write display update", e))?; } } @@ -927,7 +934,7 @@ impl RdpServer { writer: &mut impl FramedWrite, io_channel_id: u16, user_channel_id: u16, - ) -> Result { + ) -> ServerResult { // Avoid wave messages queuing up and causing extra delay. When a // batch carries more than `WAVE_KEEP` waves, drop the OLDEST ones // and keep the most recent — playing stale audio just bakes the @@ -978,12 +985,16 @@ impl RdpServer { continue; } } - .context("failed to send rdpsnd event")?; + .map_err(|e| ServerError::custom("failed to send rdpsnd event", e))?; let channel_id = self .get_channel_id_by_type::() - .context("SVC channel not found")?; - let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?; - writer.write_all(&data).await?; + .ok_or_else(|| ServerError::channel("SVC channel not found"))?; + let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id) + .map_err(ServerError::encode)?; + writer + .write_all(&data) + .await + .map_err(|e| ServerError::io("write_all", e))?; } ServerEvent::Clipboard(c) => { let Some(cliprdr) = self.get_svc_processor::() else { @@ -1001,12 +1012,16 @@ impl RdpServer { continue; } } - .context("failed to send clipboard event")?; + .map_err(|e| ServerError::custom("failed to send clipboard event", e))?; let channel_id = self .get_channel_id_by_type::() - .context("SVC channel not found")?; - let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?; - writer.write_all(&data).await?; + .ok_or_else(|| ServerError::channel("SVC channel not found"))?; + let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id) + .map_err(ServerError::encode)?; + writer + .write_all(&data) + .await + .map_err(|e| ServerError::io("write_all", e))?; } ServerEvent::Echo(msg) => match msg { EchoServerMessage::SendRequest { payload } => { @@ -1029,14 +1044,19 @@ impl RdpServer { let request = build_echo_request(payload)?; let messages = - dvc::encode_dvc_messages(echo_channel_id, vec![request], ChannelFlags::SHOW_PROTOCOL)?; + dvc::encode_dvc_messages(echo_channel_id, vec![request], ChannelFlags::SHOW_PROTOCOL) + .map_err(ServerError::encode)?; let drdynvc_channel_id = self .get_channel_id_by_type::() - .context("DRDYNVC channel not found")?; - - let data = server_encode_svc_messages(messages, drdynvc_channel_id, user_channel_id)?; - writer.write_all(&data).await?; + .ok_or_else(|| ServerError::channel("DRDYNVC channel not found"))?; + + let data = server_encode_svc_messages(messages, drdynvc_channel_id, user_channel_id) + .map_err(ServerError::encode)?; + writer + .write_all(&data) + .await + .map_err(|e| ServerError::io("write_all", e))?; } }, #[cfg(feature = "egfx")] @@ -1044,9 +1064,13 @@ impl RdpServer { EgfxServerMessage::SendMessages { messages } => { let drdynvc_channel_id = self .get_channel_id_by_type::() - .context("DRDYNVC channel not found")?; - let data = server_encode_svc_messages(messages, drdynvc_channel_id, user_channel_id)?; - writer.write_all(&data).await?; + .ok_or_else(|| ServerError::channel("DRDYNVC channel not found"))?; + let data = server_encode_svc_messages(messages, drdynvc_channel_id, user_channel_id) + .map_err(ServerError::encode)?; + writer + .write_all(&data) + .await + .map_err(|e| ServerError::io("write_all", e))?; } }, ServerEvent::AutoDetectRttRequest => { @@ -1058,7 +1082,10 @@ impl RdpServer { io_channel_id, user_channel_id, )?; - writer.write_all(&data).await?; + writer + .write_all(&data) + .await + .map_err(|e| ServerError::io("write_all", e))?; } } } @@ -1074,13 +1101,13 @@ impl RdpServer { io_channel_id: u16, user_channel_id: u16, mut encoder: UpdateEncoder, - ) -> Result + ) -> ServerResult where R: FramedRead, W: FramedWrite, { debug!("Starting client loop"); - let mut display_updates = self.display.lock().await.updates().await?; + let mut display_updates = self.display.lock().await.updates().await.map_err(from_anyhow)?; let mut writer = SharedWriter::new(writer); let mut display_writer = writer.clone(); let mut event_writer = writer.clone(); @@ -1090,7 +1117,10 @@ impl RdpServer { let this = Rc::clone(&s); let dispatch_pdu = async move { loop { - let (action, bytes) = reader.read_pdu().await?; + let (action, bytes) = reader + .read_pdu() + .await + .map_err(|e| ServerError::custom("read pdu", e))?; let mut this = this.lock().await; match this .dispatch_pdu(action, bytes, &mut writer, io_channel_id, user_channel_id) @@ -1176,7 +1206,7 @@ impl RdpServer { reader: &mut Framed, writer: &mut Framed, result: AcceptorResult, - ) -> Result + ) -> ServerResult where R: FramedRead, W: FramedWrite, @@ -1196,12 +1226,12 @@ impl RdpServer { Ok(CredentialDecision::Reject) => { warn!("Credential validation rejected"); send_access_denied(result.io_channel_id, result.user_channel_id, writer).await?; - bail!("credential validation rejected"); + return Err(ServerError::reason("credential validation", "rejected by validator")); } Err(e) => { error!(error = %e, "Credential validator backend error"); send_access_denied(result.io_channel_id, result.user_channel_id, writer).await?; - bail!("credential validation backend error"); + return Err(ServerError::custom("credential validation", e)); } } } else { @@ -1227,9 +1257,13 @@ impl RdpServer { let Some(channel_id) = channel_id else { continue; }; - let svc_responses = channel.start()?; - let response = server_encode_svc_messages(svc_responses, channel_id, result.user_channel_id)?; - writer.write_all(&response).await?; + let svc_responses = channel.start().map_err(|e| ServerError::custom("svc start", e))?; + let response = server_encode_svc_messages(svc_responses, channel_id, result.user_channel_id) + .map_err(ServerError::encode)?; + writer + .write_all(&response) + .await + .map_err(|e| ServerError::io("write svc response", e))?; } } @@ -1240,7 +1274,7 @@ impl RdpServer { CapabilitySet::General(c) => { let fastpath = c.extra_flags.contains(GeneralExtraFlags::FASTPATH_OUTPUT_SUPPORTED); if !fastpath { - bail!("Fastpath output not supported!"); + return Err(ServerError::unsupported("Fastpath output")); } } CapabilitySet::Bitmap(b) => { @@ -1320,12 +1354,12 @@ impl RdpServer { let desktop_size = self.display.lock().await.size().await; let encoder = UpdateEncoder::new(desktop_size, surface_flags, update_codecs, self.opts.max_request_size) - .context("failed to initialize update encoder")?; + .with_context("failed to initialize update encoder")?; let state = self .client_loop(reader, writer, result.io_channel_id, result.user_channel_id, encoder) .await - .context("client loop failure")?; + .with_context("client loop failure")?; Ok(state) } @@ -1336,11 +1370,11 @@ impl RdpServer { io_channel_id: u16, user_channel_id: u16, frames: Vec>, - ) -> Result<()> { + ) -> ServerResult<()> { for frame in frames { match Action::from_fp_output_header(frame[0]) { Ok(Action::FastPath) => { - let input = decode(&frame)?; + let input = decode(&frame).map_err(ServerError::decode)?; self.handle_fastpath(input).await; } @@ -1392,8 +1426,8 @@ impl RdpServer { } } - async fn handle_io_channel_data(&mut self, data: SendDataRequest<'_>) -> Result { - let control: rdp::headers::ShareControlHeader = decode(data.user_data.as_ref())?; + async fn handle_io_channel_data(&mut self, data: SendDataRequest<'_>) -> ServerResult { + let control: rdp::headers::ShareControlHeader = decode(data.user_data.as_ref()).map_err(ServerError::decode)?; match control.share_control_pdu { ShareControlPdu::Data(header) => match header.share_data_pdu { @@ -1463,8 +1497,8 @@ impl RdpServer { io_channel_id: u16, user_channel_id: u16, frame: &[u8], - ) -> Result { - let message = decode::>>(frame)?; + ) -> ServerResult { + let message = decode::>>(frame).map_err(ServerError::decode)?; match message.0 { mcs::McsMessage::SendDataRequest(data) => { debug!( @@ -1478,9 +1512,15 @@ impl RdpServer { } if let Some(svc) = self.static_channels.get_by_channel_id_mut(data.channel_id) { - let response_pdus = svc.process(&data.user_data)?; - let response = server_encode_svc_messages(response_pdus, data.channel_id, user_channel_id)?; - writer.write_all(&response).await?; + let response_pdus = svc + .process(&data.user_data) + .map_err(|e| ServerError::custom("svc process", e))?; + let response = server_encode_svc_messages(response_pdus, data.channel_id, user_channel_id) + .map_err(ServerError::encode)?; + writer + .write_all(&response) + .await + .map_err(|e| ServerError::io("write svc response", e))?; } else { warn!(channel_id = data.channel_id, "Unexpected channel received: ID",); } @@ -1533,14 +1573,18 @@ impl RdpServer { } } - async fn accept_finalize(&mut self, mut framed: TokioFramed, mut acceptor: Acceptor) -> Result> + async fn accept_finalize( + &mut self, + mut framed: TokioFramed, + mut acceptor: Acceptor, + ) -> ServerResult> where S: AsyncRead + AsyncWrite + Sync + Send + Unpin, { loop { let (new_framed, result) = ironrdp_acceptor::accept_finalize(framed, &mut acceptor) .await - .context("failed to accept client during finalize")?; + .map_err(|e| ServerError::custom("failed to accept client during finalize", e))?; let (mut reader, mut writer) = split_tokio_framed(new_framed); @@ -1557,7 +1601,8 @@ impl RdpServer { acceptor, core::mem::take(&mut self.static_channels), desktop_size, - )?; + ) + .map_err(|e| ServerError::custom("deactivation-reactivation acceptor", e))?; framed = unsplit_tokio_framed(reader, writer); continue; } @@ -1585,7 +1630,7 @@ fn encode_share_data_pdu( share_data_pdu: rdp::headers::ShareDataPdu, io_channel_id: u16, user_channel_id: u16, -) -> Result> { +) -> ServerResult> { let header = rdp::headers::ShareDataHeader { share_data_pdu, stream_priority: rdp::headers::StreamPriority::Medium, @@ -1597,34 +1642,33 @@ fn encode_share_data_pdu( pdu_source: user_channel_id, share_control_pdu: ShareControlPdu::Data(header), }; - let user_data = encode_vec(&pdu)?.into(); + let user_data = encode_vec(&pdu).map_err(ServerError::encode)?.into(); let mcs_pdu = SendDataIndication { initiator_id: user_channel_id, channel_id: io_channel_id, user_data, }; - Ok(encode_vec(&X224(mcs_pdu))?) + encode_vec(&X224(mcs_pdu)).map_err(ServerError::encode) } -async fn deactivate_all( - io_channel_id: u16, - user_channel_id: u16, - writer: &mut impl FramedWrite, -) -> Result<(), anyhow::Error> { +async fn deactivate_all(io_channel_id: u16, user_channel_id: u16, writer: &mut impl FramedWrite) -> ServerResult<()> { let pdu = ShareControlPdu::ServerDeactivateAll(ServerDeactivateAll); let pdu = rdp::headers::ShareControlHeader { share_id: 0, pdu_source: io_channel_id, share_control_pdu: pdu, }; - let user_data = encode_vec(&pdu)?.into(); + let user_data = encode_vec(&pdu).map_err(ServerError::encode)?.into(); let pdu = SendDataIndication { initiator_id: user_channel_id, channel_id: io_channel_id, user_data, }; - let msg = encode_vec(&X224(pdu))?; - writer.write_all(&msg).await?; + let msg = encode_vec(&X224(pdu)).map_err(ServerError::encode)?; + writer + .write_all(&msg) + .await + .map_err(|e| ServerError::io("write deactivate_all", e))?; Ok(()) } @@ -1636,18 +1680,21 @@ async fn send_access_denied( io_channel_id: u16, user_channel_id: u16, writer: &mut impl FramedWrite, -) -> Result<(), anyhow::Error> { +) -> ServerResult<()> { let info = ServerSetErrorInfoPdu(ErrorInfo::ProtocolIndependentCode( ProtocolIndependentCode::ServerDeniedConnection, )); - let user_data = encode_vec(&info)?.into(); + let user_data = encode_vec(&info).map_err(ServerError::encode)?.into(); let pdu = SendDataIndication { initiator_id: user_channel_id, channel_id: io_channel_id, user_data, }; - let msg = encode_vec(&X224(pdu))?; - writer.write_all(&msg).await?; + let msg = encode_vec(&X224(pdu)).map_err(ServerError::encode)?; + writer + .write_all(&msg) + .await + .map_err(|e| ServerError::io("write access_denied", e))?; Ok(()) } diff --git a/crates/ironrdp/examples/server.rs b/crates/ironrdp/examples/server.rs index 71db6d021..30a199b6a 100644 --- a/crates/ironrdp/examples/server.rs +++ b/crates/ironrdp/examples/server.rs @@ -437,5 +437,5 @@ async fn run( domain: None, })); - server.run().await + server.run().await.map_err(|e| anyhow::anyhow!(e)) }