diff --git a/src/grpc.rs b/src/grpc.rs index 5c20ae2..19d9f80 100644 --- a/src/grpc.rs +++ b/src/grpc.rs @@ -4,7 +4,7 @@ use std::{ net::SocketAddr, sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, Mutex, + Arc, Mutex, RwLock, }, }; @@ -44,19 +44,19 @@ pub(crate) struct ProxyServer { current_id: Arc, clients: Arc>, results: Arc>>>, - http_channel: mpsc::UnboundedSender, pub(crate) connected: Arc, pub(crate) core_version: Arc>>, config: Arc>>, + cookie_key: Arc>>, setup_in_progress: Arc, } impl ProxyServer { #[must_use] /// Create new `ProxyServer`. - pub(crate) fn new(http_channel: mpsc::UnboundedSender) -> Self { + pub(crate) fn new(cookie_key: Arc>>) -> Self { Self { - http_channel, + cookie_key, current_id: Arc::new(AtomicU64::new(1)), clients: Arc::new(Mutex::new(HashMap::new())), results: Arc::new(Mutex::new(HashMap::new())), @@ -193,7 +193,7 @@ impl Clone for ProxyServer { results: Arc::clone(&self.results), connected: Arc::clone(&self.connected), core_version: Arc::clone(&self.core_version), - http_channel: self.http_channel.clone(), + cookie_key: Arc::clone(&self.cookie_key), config: Arc::clone(&self.config), setup_in_progress: Arc::clone(&self.setup_in_progress), } @@ -252,10 +252,7 @@ impl proxy_server::Proxy for ProxyServer { Key::generate() } }; - self.http_channel.send(key).map_err(|err| { - error!("Failed to send private cookies key to HTTP server: {err:?}"); - Status::internal("Failed to send private cookies key to HTTP server") - })?; + *self.cookie_key.write().unwrap() = Some(key); let (tx, rx) = mpsc::unbounded_channel(); self.clients diff --git a/src/http.rs b/src/http.rs index 8d165df..c9c5180 100644 --- a/src/http.rs +++ b/src/http.rs @@ -3,7 +3,7 @@ use std::{ fs::read_to_string, net::{IpAddr, Ipv4Addr, SocketAddr}, path::Path, - sync::{atomic::Ordering, Arc, LazyLock}, + sync::{atomic::Ordering, Arc, LazyLock, RwLock}, time::Duration, }; @@ -13,6 +13,7 @@ use axum::{ extract::{ConnectInfo, FromRef, State}, http::{header::HeaderValue, Request, Response, StatusCode}, middleware::{self, Next}, + response::IntoResponse, routing::{get, post}, serve, Json, Router, }; @@ -22,7 +23,7 @@ use defguard_version::{server::DefguardVersionLayer, Version}; use serde::Serialize; use tokio::{ net::TcpListener, - sync::{mpsc, oneshot, Mutex}, + sync::{oneshot, Mutex}, task::JoinSet, }; use tower_governor::{ @@ -63,7 +64,7 @@ pub(crate) struct AppState { pub(crate) grpc_server: ProxyServer, pub(crate) remote_mfa_sessions: Arc>>>, - key: Key, + cookie_key: Arc>>, url: Url, } @@ -85,7 +86,10 @@ impl AppState { impl FromRef for Key { fn from_ref(state: &AppState) -> Self { - state.key.clone() + let maybe_key = state.cookie_key.read().unwrap().clone(); + // We return the dummy key only to satisfy the `FromRef` trait, but it is never + // used in practice because of the `ensure_configured` middleware. + maybe_key.unwrap_or_else(|| Key::from(&[0; 64])) } } @@ -173,18 +177,44 @@ async fn powered_by_header(mut response: Response) -> Response { response } +/// Middleware that gates all HTTP endpoints except health checks until the proxy +/// is fully configured. +/// +/// The proxy cannot safely handle requests that rely on encrypted cookies +/// (e.g. OpenID / MFA flows) until it receives the cookie encryption key from +/// the core. This key is provided asynchronously after the core connects. +/// +/// Until the key is available, only health check endpoints are served and all +/// other requests return HTTP 503 (Service Unavailable). Once the key is set, +/// the middleware becomes a no-op and all routes are enabled. +async fn ensure_configured( + State(state): State, + request: Request, + next: Next, +) -> Response { + // Allow healthchecks even before core connects and gives us the cookie key. + let path = request.uri().path(); + if matches!(path, "/api/v1/health" | "/api/v1/health-grpc") { + return next.run(request).await; + } + + // Block all other requests until cookie key is configured. + if state.cookie_key.read().unwrap().is_none() { + return StatusCode::SERVICE_UNAVAILABLE.into_response(); + } + + next.run(request).await +} + pub async fn run_server(config: Config) -> anyhow::Result<()> { info!("Starting Defguard Proxy server"); debug!("Using config: {config:?}"); let mut tasks = JoinSet::new(); - - // Prepare the channel for gRPC -> http server communication. - // The channel sends private cookies key once core connects to gRPC. - let (tx, mut rx) = mpsc::unbounded_channel::(); + let cookie_key = Default::default(); // connect to upstream gRPC server - let grpc_server = ProxyServer::new(tx); + let grpc_server = ProxyServer::new(Arc::clone(&cookie_key)); let server_clone = grpc_server.clone(); @@ -256,15 +286,10 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { } }); - // Wait for core to connect to gRPC and send private cookies encryption key. - let Some(key) = rx.recv().await else { - return Err(anyhow::Error::msg("http channel closed")); - }; - // build application debug!("Setting up API server"); let shared_state = AppState { - key, + cookie_key, grpc_server, remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())), url: config.url.clone(), @@ -324,6 +349,10 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { .route("/info", get(app_info)), ) .fallback_service(get(handle_404)) + .layer(middleware::from_fn_with_state( + shared_state.clone(), + ensure_configured, + )) .layer(middleware::map_response(powered_by_header)) .layer(middleware::from_fn_with_state( shared_state.clone(),