diff --git a/CHANGELOG.md b/CHANGELOG.md index ec83f55..54410bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed + +- Hot-reload config storage: `Arc>` → `Arc>` (lock-free reads, `Arc` snapshots instead of full `Config` clones) +- `Proxy::config_snapshot()` is now synchronous and returns `Arc` (was `async fn` returning owned `Config`) +- `Proxy::update_config()` is now synchronous (was `async fn`) + +### Fixed + +- Hot-reload now applies on every HTTP request, including subsequent requests on keep-alive connections +- Hot-reload on TLS listeners started via `start_with_addr` / `start_tls` now picks up routing changes without restart + +### Added + +- Dependency: `arc-swap` +- Integration test: config hot-reload over a single keep-alive connection + ## [0.4.0] - 2026-05-25 ### Added diff --git a/Cargo.lock b/Cargo.lock index 0fa6813..e74ee66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -73,6 +73,15 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "arc-swap" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207" +dependencies = [ + "rustversion", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -807,9 +816,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.42" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -1178,6 +1187,7 @@ name = "tiny-proxy" version = "0.4.0" dependencies = [ "anyhow", + "arc-swap", "bytes", "clap", "criterion", diff --git a/Cargo.toml b/Cargo.toml index deb4db9..961c8ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ thiserror = "2.0.18" bytes = "1.0" num_cpus = "1.16" uuid = { version = "1", features = ["v4"] } +arc-swap = "1" serde = { version = "1.0", features = ["derive"], optional = true } serde_json = { version = "1.0", optional = true } diff --git a/README.md b/README.md index f7223f4..b229383 100644 --- a/README.md +++ b/README.md @@ -163,8 +163,8 @@ async fn main() -> anyhow::Result<()> { #### Hot-Reload Configuration -Update configuration at runtime without restart. The proxy uses `Arc>` internally, -so routing and directive changes take effect immediately for new connections. +Update configuration at runtime without restart. The proxy uses `Arc>` internally, +so routing and directive changes take effect on the next request (including keep-alive connections). > **TLS certificates**: cert/key files and `TlsAcceptor` are loaded when a listener starts. > Hot-reload updates site routing and directives, but **not** TLS certificates — to pick up @@ -173,9 +173,9 @@ so routing and directive changes take effect immediately for new connections. Example: ```rust +use arc_swap::ArcSwap; use tiny_proxy::{Config, Proxy}; use std::sync::Arc; -use tokio::sync::RwLock; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -192,12 +192,9 @@ async fn main() -> anyhow::Result<()> { } }); - // Update config at runtime — takes effect immediately + // Update config at runtime — takes effect on the next request let new_config = Config::from_file("new-config.conf")?; - { - let mut guard = config_handle.write().await; - *guard = new_config; - } + config_handle.store(Arc::new(new_config)); handle.await?; Ok(()) @@ -208,7 +205,7 @@ Or use the built-in `update_config` method: ```rust let new_config = Config::from_file("updated-config.conf")?; -proxy.update_config(new_config).await; +proxy.update_config(new_config); ``` ## Configuration @@ -493,11 +490,11 @@ Enable TLS support — both **frontend TLS termination** (HTTPS listeners) and * Management API for runtime configuration: ```rust +use arc_swap::ArcSwap; use tiny_proxy::api; use std::sync::Arc; -use tokio::sync::RwLock; -let config = Arc::new(RwLock::new(Config::from_file("config.conf")?)); +let config = Arc::new(ArcSwap::from_pointee(Config::from_file("config.conf")?)); api::start_api_server("127.0.0.1:8081", config).await?; ``` @@ -517,11 +514,11 @@ See the [module documentation](https://docs.rs/tiny-proxy) for detailed API refe - `Config::from_file(path)` - Load configuration from file - `Config::from_str(content)` - Parse configuration from string - `Proxy::new(config)` - Create proxy instance -- `Proxy::from_shared(config)` - Create proxy from shared `Arc>` +- `Proxy::from_shared(config)` - Create proxy from shared `Arc>` - `Proxy::start(addr)` - Start proxy server -- `Proxy::shared_config()` - Get `Arc>` for external config updates -- `Proxy::config_snapshot()` - Read current configuration as owned value -- `Proxy::update_config(config)` - Update configuration at runtime (async) +- `Proxy::shared_config()` - Get `Arc>` for external config updates +- `Proxy::config_snapshot()` - Read current configuration as `Arc` +- `Proxy::update_config(config)` - Update configuration at runtime ## Testing diff --git a/examples/hot_reload.rs b/examples/hot_reload.rs index 1d77c1f..5aa6000 100644 --- a/examples/hot_reload.rs +++ b/examples/hot_reload.rs @@ -1,108 +1,65 @@ //! Example of hot-reloading configuration without restarting the proxy //! -//! This example demonstrates how to: -//! - Load configuration from a file -//! - Create a Proxy instance -//! - Start the proxy server in the background -//! - Monitor the configuration file for changes -//! - Hot-reload the configuration when the file changes -//! //! Run with: //! ```bash //! cargo run --example hot_reload //! ``` //! -//! Then edit file.conf while the proxy is running to see hot-reload in action. +//! Then edit `file.conf` while the proxy is running. +use arc_swap::ArcSwap; +use std::sync::Arc; use tiny_proxy::{Config, Proxy}; use tokio::time::{sleep, Duration}; -use tracing::{error, info, warn}; +use tracing::{error, info}; use tracing_subscriber::{fmt, EnvFilter}; #[tokio::main] async fn main() -> anyhow::Result<()> { - // Initialize logging fmt() .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) .init(); - info!("Starting tiny-proxy hot-reload example"); - let config_path = "file.conf"; - - // Load initial configuration from file let config = Config::from_file(config_path)?; info!( "Loaded initial configuration for {} site(s)", config.sites.len() ); - // Create proxy instance - let proxy = std::sync::Arc::new(Proxy::new(config)); + let shared = Arc::new(ArcSwap::from_pointee(config)); + let proxy = Proxy::from_shared(shared.clone()); - // Spawn the proxy in a background task let _proxy_handle = tokio::spawn(async move { if let Err(e) = proxy.start("127.0.0.1:8080").await { eprintln!("Proxy error: {}", e); } }); - info!("Proxy started in background on http://127.0.0.1:8080"); + info!("Proxy started on http://127.0.0.1:8080"); info!("Monitoring {} for changes...", config_path); - info!("Edit the configuration file to see hot-reload in action"); - // Track the last modification time of the config file let mut last_modified = tokio::fs::metadata(config_path).await?.modified()?; - // Monitor for configuration changes loop { sleep(Duration::from_secs(2)).await; - // Check if file has been modified - match tokio::fs::metadata(config_path).await { - Ok(metadata) => { - if let Ok(modified) = metadata.modified() { - if modified != last_modified { - info!("Configuration file changed, reloading..."); - - // Try to load the new configuration - match Config::from_file(config_path) { - Ok(new_config) => { - info!( - "Successfully loaded new configuration with {} site(s)", - new_config.sites.len() - ); - - // Note: We can't update the running proxy directly from here - // because it's in a separate task. In a real application, you would: - // 1. Use Arc> to share the proxy between tasks - // 2. Or use a channel to send the new config to the proxy task - // 3. Or implement a shared configuration store with Arc> - - warn!("Note: This example demonstrates hot-reload detection."); - warn!("To actually update the running proxy, you need to share"); - warn!("the proxy instance with Arc> or similar."); - - // For demonstration purposes, we just show the detection - let site_count = new_config.sites.len(); - info!( - "New configuration would be applied with {} site(s)", - site_count - ); + let metadata = tokio::fs::metadata(config_path).await?; + let modified = metadata.modified()?; + if modified == last_modified { + continue; + } - // Update last_modified timestamp - last_modified = modified; - } - Err(e) => { - error!("Failed to reload configuration: {}", e); - warn!("Continuing with current configuration"); - } - } - } - } + info!("Configuration file changed, reloading..."); + match Config::from_file(config_path) { + Ok(new_config) => { + let sites_count = new_config.sites.len(); + shared.store(Arc::new(new_config)); + info!("Configuration updated ({} sites)", sites_count); + last_modified = modified; } Err(e) => { - error!("Failed to read file metadata: {}", e); + error!("Failed to reload configuration: {}", e); } } } diff --git a/src/api/endpoints.rs b/src/api/endpoints.rs index db15ce8..e6c8e9e 100644 --- a/src/api/endpoints.rs +++ b/src/api/endpoints.rs @@ -1,6 +1,7 @@ //! API endpoints for proxy management use anyhow::Result; +use arc_swap::ArcSwap; use bytes::Bytes; use http_body::Body; use http_body_util::{BodyExt, Full}; @@ -13,12 +14,12 @@ use crate::config::Config; /// Handle GET /config pub async fn handle_get_config( _req: Request, - config: Arc>, + config: Arc>, ) -> Result>> where B: Body, { - let config = config.read().await; + let config = config.load_full(); let json = serde_json::to_string_pretty(&*config) .unwrap_or_else(|_| r#"{"error": "Failed to serialize config"}"#.to_string()); @@ -57,7 +58,7 @@ where /// ``` pub async fn handle_post_config( req: Request, - config: Arc>, + config: Arc>, ) -> Result>> where B: Body, @@ -121,9 +122,8 @@ where // Atomically replace the configuration { - let mut guard = config.write().await; let sites_count = new_config.sites.len(); - *guard = new_config; + config.store(Arc::new(new_config)); info!( "POST /config - Configuration updated successfully ({} sites)", sites_count @@ -182,7 +182,7 @@ mod tests { #[tokio::test] async fn test_handle_get_config() { - let config = Arc::new(tokio::sync::RwLock::new(Config { + let config = Arc::new(ArcSwap::from_pointee(Config { sites: HashMap::new(), })); @@ -194,7 +194,7 @@ mod tests { #[tokio::test] async fn test_handle_post_config_valid_json() { - let config = Arc::new(tokio::sync::RwLock::new(Config { + let config = Arc::new(ArcSwap::from_pointee(Config { sites: HashMap::new(), })); @@ -219,14 +219,14 @@ mod tests { assert_eq!(response.status(), 200); // Verify config was actually updated - let guard = config.read().await; + let guard = config.load_full(); assert_eq!(guard.sites.len(), 1); assert!(guard.sites.contains_key("localhost:8080")); } #[tokio::test] async fn test_handle_post_config_invalid_json() { - let config = Arc::new(tokio::sync::RwLock::new(Config { + let config = Arc::new(ArcSwap::from_pointee(Config { sites: HashMap::new(), })); @@ -240,13 +240,13 @@ mod tests { assert_eq!(response.status(), 400); // Verify config was NOT updated - let guard = config.read().await; + let guard = config.load_full(); assert_eq!(guard.sites.len(), 0); } #[tokio::test] async fn test_handle_post_config_empty_body() { - let config = Arc::new(tokio::sync::RwLock::new(Config { + let config = Arc::new(ArcSwap::from_pointee(Config { sites: HashMap::new(), })); diff --git a/src/api/mod.rs b/src/api/mod.rs index 6a27b50..a9eba98 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,11 +9,11 @@ //! use tiny_proxy::api; //! use tiny_proxy::Config; //! use std::sync::Arc; -//! use tokio::sync::RwLock; +//! use arc_swap::ArcSwap; //! //! # #[tokio::main] //! # async fn main() -> anyhow::Result<()> { -//! let config = Arc::new(RwLock::new(Config::from_file("config.conf")?)); +//! let config = Arc::new(ArcSwap::from_pointee(Config::from_file("config.conf")?)); //! //! // Start the management API server //! api::start_api_server("127.0.0.1:8081", config).await?; diff --git a/src/api/server.rs b/src/api/server.rs index 1299a85..42fce5d 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -3,6 +3,7 @@ //! This module provides a REST API for managing the proxy configuration, //! including viewing and updating configuration settings. +use arc_swap::ArcSwap; use http_body_util::Full; use hyper::body::Incoming; use hyper::server::conn::http1; @@ -29,21 +30,22 @@ use crate::error::Result; /// # Arguments /// /// * `addr` - Address to listen on (e.g., "127.0.0.1:8081") -/// * `config` - Shared configuration wrapped in Arc> +/// * `config` - Shared configuration wrapped in `Arc>` /// /// # Example /// /// ```no_run /// # use tiny_proxy::{Config, api}; +/// # use arc_swap::ArcSwap; /// # use std::sync::Arc; /// # #[tokio::main] /// # async fn main() -> anyhow::Result<()> { -/// let config = Arc::new(tokio::sync::RwLock::new(Config::from_file("config.conf")?)); +/// let config = Arc::new(ArcSwap::from_pointee(Config::from_file("config.conf")?)); /// api::server::start_api_server("127.0.0.1:8081", config).await?; /// # Ok(()) /// # } /// ``` -pub async fn start_api_server(addr: &str, config: Arc>) -> Result<()> { +pub async fn start_api_server(addr: &str, config: Arc>) -> Result<()> { let addr: SocketAddr = addr.parse()?; start_api_server_with_addr(addr, config).await } @@ -55,10 +57,10 @@ pub async fn start_api_server(addr: &str, config: Arc> +/// * `config` - Shared configuration wrapped in `Arc>` pub async fn start_api_server_with_addr( addr: SocketAddr, - config: Arc>, + config: Arc>, ) -> Result<()> { let listener = TcpListener::bind(&addr).await?; @@ -87,7 +89,7 @@ pub async fn start_api_server_with_addr( /// Routes requests to appropriate endpoints based on method and path. async fn handle_api_request( req: Request, - config: Arc>, + config: Arc>, ) -> anyhow::Result>> { // TODO: Add authentication middleware if needed // let req = middleware::auth_middleware(req, api_key).await?; diff --git a/src/main.rs b/src/main.rs index 1b79189..4d00543 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,10 @@ use tiny_proxy::config::Config; use std::sync::Arc; #[cfg(feature = "api")] -use tokio::sync::{broadcast, RwLock}; +use arc_swap::ArcSwap; + +#[cfg(feature = "api")] +use tokio::sync::broadcast; #[cfg(feature = "api")] use tiny_proxy::start_api_server; @@ -96,7 +99,7 @@ async fn run_proxy_only(cli: Cli, config: Config) -> Result<(), anyhow::Error> { #[cfg(feature = "api")] async fn run_with_api(cli: Cli, config: Config) -> Result<(), anyhow::Error> { // Create shared configuration - let shared_config = Arc::new(RwLock::new(config.clone())); + let shared_config = Arc::new(ArcSwap::from_pointee(config.clone())); // Create shutdown channel let (shutdown_tx, _) = broadcast::channel::<()>(1); @@ -206,7 +209,7 @@ async fn run_with_api(cli: Cli, config: Config) -> Result<(), anyhow::Error> { async fn run_proxy_server( addr: Option, max_concurrency: usize, - shared_config: Arc>, + shared_config: Arc>, mut shutdown_rx: broadcast::Receiver<()>, ) -> Result<(), anyhow::Error> { // Create proxy from the shared config handle — any updates to @@ -247,7 +250,7 @@ async fn run_proxy_server( #[cfg(feature = "api")] async fn run_api_server( addr: String, - shared_config: Arc>, + shared_config: Arc>, mut shutdown_rx: broadcast::Receiver<()>, ) -> Result<(), anyhow::Error> { tokio::select! { diff --git a/src/proxy/handler.rs b/src/proxy/handler.rs index 0f5fc95..bffadf4 100644 --- a/src/proxy/handler.rs +++ b/src/proxy/handler.rs @@ -147,35 +147,28 @@ pub async fn proxy( // Generate or reuse request ID let initial_request_id = ensure_request_id(&mut req); - // Extract request info before processing - #[cfg(feature = "logging")] - let method = req.method().clone().to_string(); - let path = req.uri().path().to_string(); - let host = req - .headers() - .get(hyper::header::HOST) - .and_then(|h| h.to_str().ok()) - .unwrap_or("localhost") - .to_string(); - #[cfg(feature = "logging")] let span = info_span!("request", req_id = %initial_request_id); - #[allow(unused_variables)] let future = async move { + let path = req.uri().path().to_string(); + let host = req + .headers() + .get(hyper::header::HOST) + .and_then(|h| h.to_str().ok()) + .unwrap_or("localhost"); + #[cfg(feature = "logging")] let mut log_guard = AccessLogGuard::new( initial_request_id.clone(), remote_addr, - method, + req.method().to_string(), path.clone(), - host.clone(), + host.to_string(), ); // Find site configuration by host - // Browsers send Host: example.com (no port) for default ports, - // but config keys may be "example.com:443". Try both. - let site_config = match find_site(&config, &host, is_tls) { + let site_config = match find_site(&config, host, is_tls) { Some(config) => config, None => { error!("No configuration found for host: {}", host); @@ -469,18 +462,19 @@ pub fn match_pattern(pattern: &str, path: &str) -> Option { /// /// But config keys include the port: `"example.com:443"`, `"example.com:80"`. /// -/// This function tries multiple lookup strategies: -/// 1. Exact match: `host` as-is -/// 2. If `host` has no port → try `host:` based on `is_tls` -/// 3. If `host` has a port → also try just the hostname (in case config has no port) +/// Look up a site by Host header value. +/// +/// Tries, in order: +/// 1. Exact match on the raw host string. +/// 2. If host has no port → append default port (443 for TLS, 80 for HTTP) and retry. +/// 3. For TLS, match by SNI hostname if exactly one site matches. +/// 4. If host has a port → strip it and try the bare hostname. /// /// # Limitations /// -/// For non-default TLS ports (e.g., 8443), browsers always include the port -/// in the `Host` header (`Host: example.com:8443`), so strategy 1 (exact match) -/// works fine. The fallback in strategy 2 only tries ports 443 (TLS) and 80 (HTTP). -/// This means a non-browser client sending `Host: example.com` without a port to -/// a TLS listener on :8443 will get a 404 — this is a protocol violation by the client. +/// For non-default TLS ports (e.g., 8443), browsers include the port in `Host` +/// (`Host: example.com:8443`), so exact match works. A client sending `Host: example.com` +/// without a port to a TLS listener on :8443 gets 404 — that violates normal HTTP usage. pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a SiteConfig> { // 1. Exact match if let Some(site) = config.sites.get(host) { @@ -525,11 +519,12 @@ pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a let hostname = if host.starts_with('[') { // IPv6 [::1]:port → ::1 let end = host.find(']').unwrap_or(host.len()); - host[1..end].to_string() + &host[1..end] } else { - host.rsplit(':').next_back().unwrap_or(host).to_string() + // example.com:443 → example.com + host.rsplit_once(':').map(|(name, _)| name).unwrap_or(host) }; - if let Some(site) = config.sites.get(&hostname) { + if let Some(site) = config.sites.get(hostname) { return Some(site); } } diff --git a/src/proxy/proxy.rs b/src/proxy/proxy.rs index 1d6cd6b..e22c57b 100644 --- a/src/proxy/proxy.rs +++ b/src/proxy/proxy.rs @@ -1,3 +1,4 @@ +use arc_swap::ArcSwap; use hyper::body::Incoming; use hyper::service::service_fn; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; @@ -10,7 +11,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; -use tokio::sync::{RwLock, Semaphore}; +use tokio::sync::Semaphore; use tracing::{error, info, warn}; #[cfg(feature = "tls")] @@ -22,7 +23,7 @@ use crate::proxy::handler::proxy; /// HTTP Proxy server that can be embedded into other applications /// /// This struct encapsulates the proxy state and allows programmatic control -/// over the proxy lifecycle. Configuration is stored in an `Arc>` +/// over the proxy lifecycle. Configuration is stored in an `Arc>` /// so it can be hot-reloaded at runtime (e.g. via the API server). /// /// # Example @@ -44,8 +45,7 @@ use crate::proxy::handler::proxy; /// ```no_run /// use tiny_proxy::{Config, Proxy}; /// use std::sync::Arc; -/// use tokio::sync::RwLock; -/// +/// /// /// #[tokio::main] /// async fn main() -> anyhow::Result<()> { /// let config = Config::from_file("config.conf")?; @@ -63,17 +63,14 @@ use crate::proxy::handler::proxy; /// /// // Later, update config at runtime /// let new_config = Config::from_file("updated-config.conf")?; -/// { -/// let mut guard = config_handle.write().await; -/// *guard = new_config; -/// } +/// config_handle.store(Arc::new(new_config)); /// /// handle.await?; /// Ok(()) /// } /// ``` pub struct Proxy { - config: Arc>, + config: Arc>, client: Client, Incoming>, max_concurrency: usize, semaphore: Arc, @@ -82,7 +79,7 @@ pub struct Proxy { impl Proxy { /// Create a new proxy instance with the given configuration /// - /// The configuration is internally wrapped in `Arc>` + /// The configuration is internally wrapped in `Arc>` /// so it can be shared with an API server for hot-reload. /// /// # Arguments @@ -122,7 +119,7 @@ impl Proxy { ); Self { - config: Arc::new(RwLock::new(config)), + config: Arc::new(ArcSwap::from_pointee(config)), client, max_concurrency, semaphore, @@ -131,13 +128,13 @@ impl Proxy { /// Create a new proxy instance from an already shared configuration /// - /// Use this when you already have an `Arc>` that is + /// Use this when you already have an `Arc>` that is /// shared with an API server or other component. /// /// # Arguments /// - /// * `config` - Shared configuration wrapped in `Arc>` - pub fn from_shared(config: Arc>) -> Self { + /// * `config` - Shared configuration wrapped in `Arc>` + pub fn from_shared(config: Arc>) -> Self { let mut http = HttpConnector::new(); http.set_keepalive(Some(Duration::from_secs(60))); http.set_nodelay(true); @@ -214,7 +211,7 @@ impl Proxy { /// * `addr` - Parsed SocketAddr to listen on pub async fn start_with_addr(&self, addr: SocketAddr) -> anyhow::Result<()> { // Check if any site on this address has TLS configured - let config_snapshot = self.config.read().await.clone(); + let config_snapshot = self.config.load_full(); let tls_sites: Vec<(String, crate::config::TlsConfig)> = config_snapshot .sites .values() @@ -277,7 +274,7 @@ impl Proxy { /// # } /// ``` pub async fn start_all(&self) -> anyhow::Result<()> { - let config_snapshot = self.config.read().await.clone(); + let config_snapshot = self.config.load_full(); // Group sites by resolved listen socket (multiple hostnames may share one port) let mut socket_groups: HashMap> = @@ -332,9 +329,7 @@ impl Proxy { let client = client.clone(); let config = config.clone(); async move { - let config_guard = config.read().await; - let config_snapshot = Arc::new(config_guard.clone()); - drop(config_guard); + let config_snapshot = config.load_full(); proxy(req, client, config_snapshot, remote_addr, true).await } }) @@ -443,7 +438,7 @@ impl Proxy { async fn run_http_loop( addr: SocketAddr, client: Client, Incoming>, - config: Arc>, + config: Arc>, semaphore: Arc, max_concurrency: usize, ) -> anyhow::Result<()> { @@ -461,15 +456,13 @@ impl Proxy { Ok(permit) => { tokio::task::spawn(async move { let _permit = permit; + let service = service_fn(move |req| { let client = client.clone(); let config = config.clone(); - let config_clone = config.clone(); async move { - let config_guard = config_clone.read().await; - let config_snapshot = Arc::new(config_guard.clone()); - drop(config_guard); + let config_snapshot = config.load_full(); proxy(req, client, config_snapshot, remote_addr, false).await } }); @@ -511,10 +504,9 @@ impl Proxy { listen_tls(addr, acceptor, semaphore, move |req, remote_addr| { let client = client.clone(); let config = config.clone(); + async move { - let config_guard = config.read().await; - let config_snapshot = Arc::new(config_guard.clone()); - drop(config_guard); + let config_snapshot = config.load_full(); proxy(req, client, config_snapshot, remote_addr, true).await } }) @@ -523,27 +515,27 @@ impl Proxy { /// Get a reference to the shared configuration handle /// - /// This returns a clone of the `Arc>`, allowing + /// This returns a clone of the `Arc>`, allowing /// external code (e.g. an API server) to read and update the /// configuration at runtime. /// /// # Returns /// - /// A cloned `Arc>` - pub fn shared_config(&self) -> Arc> { + /// A cloned `Arc>` + pub fn shared_config(&self) -> Arc> { self.config.clone() } /// Get a snapshot of the current configuration /// - /// Reads the current configuration and returns an owned clone. - /// This is useful for inspecting config without holding a lock. + /// Returns the current config as `Arc`. + /// The Arc can be shared cheaply (no cloning of Config internals). /// /// # Returns /// - /// A cloned `Config` - pub async fn config_snapshot(&self) -> Config { - self.config.read().await.clone() + /// An `Arc` snapshot + pub fn config_snapshot(&self) -> Arc { + self.config.load_full() } /// Get current concurrency limit @@ -583,10 +575,9 @@ impl Proxy { /// # Arguments /// /// * `config` - New configuration to use - pub async fn update_config(&self, config: Config) { - let mut guard = self.config.write().await; + pub fn update_config(&self, config: Config) { info!("Configuration updated ({} sites)", config.sites.len()); - *guard = config; + self.config.store(Arc::new(config)); } } @@ -640,9 +631,7 @@ mod tests { sites: HashMap::new(), }; let proxy = Proxy::new(config); - // Can't check sites len synchronously anymore, use snapshot - let rt = tokio::runtime::Runtime::new().unwrap(); - let snapshot = rt.block_on(proxy.config_snapshot()); + let snapshot = proxy.config_snapshot(); assert_eq!(snapshot.sites.len(), 0); } @@ -661,7 +650,7 @@ mod tests { ); let proxy = Proxy::new(config); - let snapshot = proxy.config_snapshot().await; + let snapshot = proxy.config_snapshot(); assert_eq!(snapshot.sites.len(), 1); assert!(snapshot.sites.contains_key("localhost:8080")); } @@ -672,7 +661,7 @@ mod tests { sites: HashMap::new(), }; let proxy = Proxy::new(config1); - let snapshot = proxy.config_snapshot().await; + let snapshot = proxy.config_snapshot(); assert_eq!(snapshot.sites.len(), 0); let mut config2 = Config { @@ -687,8 +676,8 @@ mod tests { }, ); - proxy.update_config(config2).await; - let snapshot = proxy.config_snapshot().await; + proxy.update_config(config2); + let snapshot = proxy.config_snapshot(); assert_eq!(snapshot.sites.len(), 1); assert!(snapshot.sites.contains_key("test.local")); } @@ -704,8 +693,8 @@ mod tests { // Update via the shared handle { - let mut guard = handle.write().await; - guard.sites.insert( + let mut m = HashMap::new(); + m.insert( "shared.local".to_string(), crate::config::SiteConfig { address: "shared.local".to_string(), @@ -713,10 +702,11 @@ mod tests { tls: None, }, ); + handle.store(Arc::new(Config { sites: m })); } // Verify the proxy sees the update - let snapshot = proxy.config_snapshot().await; + let snapshot = proxy.config_snapshot(); assert_eq!(snapshot.sites.len(), 1); assert!(snapshot.sites.contains_key("shared.local")); } @@ -726,23 +716,23 @@ mod tests { let config = Config { sites: HashMap::new(), }; - let shared = Arc::new(RwLock::new(config)); + let shared = Arc::new(ArcSwap::from_pointee(config)); let proxy = Proxy::from_shared(shared.clone()); - // Verify both point to the same config - let rt = tokio::runtime::Runtime::new().unwrap(); - { - let mut guard = rt.block_on(shared.write()); - guard.sites.insert( - "from-shared.local".to_string(), - crate::config::SiteConfig { - address: "from-shared.local".to_string(), - directives: vec![], - tls: None, - }, - ); - } - let snapshot = rt.block_on(proxy.config_snapshot()); + // Update config via shared handle + let mut m = HashMap::new(); + m.insert( + "from-shared.local".to_string(), + crate::config::SiteConfig { + address: "from-shared.local".to_string(), + directives: vec![], + tls: None, + }, + ); + shared.store(Arc::new(Config { sites: m })); + + // Verify the proxy sees the update + let snapshot = proxy.config_snapshot(); assert_eq!(snapshot.sites.len(), 1); assert!(snapshot.sites.contains_key("from-shared.local")); } diff --git a/tests/hot_reload_integration.rs b/tests/hot_reload_integration.rs new file mode 100644 index 0000000..3f7877d --- /dev/null +++ b/tests/hot_reload_integration.rs @@ -0,0 +1,81 @@ +//! Integration test: hot-reload applies on keep-alive connections (per-request config load). + +use std::collections::HashMap; +use std::sync::Arc; + +use arc_swap::ArcSwap; +use tiny_proxy::config::{Directive, SiteConfig}; +use tiny_proxy::{Config, Proxy}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +fn make_config(host_port: &str, body: &str) -> Config { + let mut sites = HashMap::new(); + sites.insert( + host_port.to_string(), + SiteConfig { + address: host_port.to_string(), + directives: vec![Directive::Respond { + status: 200, + body: body.to_string(), + }], + tls: None, + }, + ); + Config { sites } +} + +async fn get_random_port_addr() -> std::net::SocketAddr { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + drop(listener); + addr +} + +/// Send one HTTP/1.1 request on an existing stream and return the response body. +async fn http_get_on_stream(stream: &mut TcpStream, host: &str, path: &str) -> String { + let request = format!("GET {path} HTTP/1.1\r\nHost: {host}\r\nConnection: keep-alive\r\n\r\n"); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await.unwrap(); + let response = String::from_utf8_lossy(&buf[..n]); + + // Body follows the header block (after \r\n\r\n) + response + .split("\r\n\r\n") + .nth(1) + .unwrap_or("") + .trim() + .to_string() +} + +#[tokio::test] +async fn test_hot_reload_on_keep_alive_connection() { + let addr = get_random_port_addr().await; + let host = format!("127.0.0.1:{}", addr.port()); + + let shared = Arc::new(ArcSwap::from_pointee(make_config(&host, "version-1"))); + let proxy = Proxy::from_shared(shared.clone()); + + let listen_addr = addr; + tokio::spawn(async move { + proxy.start_with_addr(listen_addr).await.unwrap(); + }); + + // Let the listener come up + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let mut stream = TcpStream::connect(addr).await.unwrap(); + + let body1 = http_get_on_stream(&mut stream, &host, "/").await; + assert_eq!(body1, "version-1"); + + shared.store(Arc::new(make_config(&host, "version-2"))); + + let body2 = http_get_on_stream(&mut stream, &host, "/").await; + assert_eq!( + body2, "version-2", + "second request on same keep-alive connection must use reloaded config" + ); +}