Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
net::SocketAddr,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
Arc, Mutex, RwLock,
},
};

Expand Down Expand Up @@ -44,19 +44,19 @@ pub(crate) struct ProxyServer {
current_id: Arc<AtomicU64>,
clients: Arc<Mutex<ClientMap>>,
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
http_channel: mpsc::UnboundedSender<Key>,
pub(crate) connected: Arc<AtomicBool>,
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
config: Arc<Mutex<Option<Configuration>>>,
cookie_key: Arc<RwLock<Option<Key>>>,
setup_in_progress: Arc<AtomicBool>,
}

impl ProxyServer {
#[must_use]
/// Create new `ProxyServer`.
pub(crate) fn new(http_channel: mpsc::UnboundedSender<Key>) -> Self {
pub(crate) fn new(cookie_key: Arc<RwLock<Option<Key>>>) -> Self {
Self {
http_channel,
cookie_key,
current_id: Arc::new(AtomicU64::new(1)),
clients: Arc::new(Mutex::new(HashMap::new())),
results: Arc::new(Mutex::new(HashMap::new())),
Expand Down Expand Up @@ -193,7 +193,7 @@ impl Clone for ProxyServer {
results: Arc::clone(&self.results),
connected: Arc::clone(&self.connected),
core_version: Arc::clone(&self.core_version),
http_channel: self.http_channel.clone(),
cookie_key: Arc::clone(&self.cookie_key),
config: Arc::clone(&self.config),
setup_in_progress: Arc::clone(&self.setup_in_progress),
}
Expand Down Expand Up @@ -252,10 +252,7 @@ impl proxy_server::Proxy for ProxyServer {
Key::generate()
}
};
self.http_channel.send(key).map_err(|err| {
error!("Failed to send private cookies key to HTTP server: {err:?}");
Status::internal("Failed to send private cookies key to HTTP server")
})?;
*self.cookie_key.write().unwrap() = Some(key);

let (tx, rx) = mpsc::unbounded_channel();
self.clients
Expand Down
59 changes: 44 additions & 15 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
fs::read_to_string,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::Path,
sync::{atomic::Ordering, Arc, LazyLock},
sync::{atomic::Ordering, Arc, LazyLock, RwLock},
time::Duration,
};

Expand All @@ -13,6 +13,7 @@ use axum::{
extract::{ConnectInfo, FromRef, State},
http::{header::HeaderValue, Request, Response, StatusCode},
middleware::{self, Next},
response::IntoResponse,
routing::{get, post},
serve, Json, Router,
};
Expand All @@ -22,7 +23,7 @@ use defguard_version::{server::DefguardVersionLayer, Version};
use serde::Serialize;
use tokio::{
net::TcpListener,
sync::{mpsc, oneshot, Mutex},
sync::{oneshot, Mutex},
task::JoinSet,
};
use tower_governor::{
Expand Down Expand Up @@ -63,7 +64,7 @@ pub(crate) struct AppState {
pub(crate) grpc_server: ProxyServer,
pub(crate) remote_mfa_sessions:
Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<String>>>>,
key: Key,
cookie_key: Arc<RwLock<Option<Key>>>,
url: Url,
}

Expand All @@ -85,7 +86,10 @@ impl AppState {

impl FromRef<AppState> for Key {
fn from_ref(state: &AppState) -> Self {
state.key.clone()
let maybe_key = state.cookie_key.read().unwrap().clone();
// We return the dummy key only to satisfy the `FromRef` trait, but it is never
// used in practice because of the `ensure_configured` middleware.
maybe_key.unwrap_or_else(|| Key::from(&[0; 64]))
}
}

Expand Down Expand Up @@ -173,18 +177,44 @@ async fn powered_by_header<B>(mut response: Response<B>) -> Response<B> {
response
}

/// Middleware that gates all HTTP endpoints except health checks until the proxy
/// is fully configured.
///
/// The proxy cannot safely handle requests that rely on encrypted cookies
/// (e.g. OpenID / MFA flows) until it receives the cookie encryption key from
/// the core. This key is provided asynchronously after the core connects.
///
/// Until the key is available, only health check endpoints are served and all
/// other requests return HTTP 503 (Service Unavailable). Once the key is set,
/// the middleware becomes a no-op and all routes are enabled.
async fn ensure_configured(
State(state): State<AppState>,
request: Request<Body>,
next: Next,
) -> Response<Body> {
// Allow healthchecks even before core connects and gives us the cookie key.
let path = request.uri().path();
if matches!(path, "/api/v1/health" | "/api/v1/health-grpc") {
return next.run(request).await;
}

// Block all other requests until cookie key is configured.
if state.cookie_key.read().unwrap().is_none() {
return StatusCode::SERVICE_UNAVAILABLE.into_response();
}

next.run(request).await
}

pub async fn run_server(config: Config) -> anyhow::Result<()> {
info!("Starting Defguard Proxy server");
debug!("Using config: {config:?}");

let mut tasks = JoinSet::new();

// Prepare the channel for gRPC -> http server communication.
// The channel sends private cookies key once core connects to gRPC.
let (tx, mut rx) = mpsc::unbounded_channel::<Key>();
let cookie_key = Default::default();

// connect to upstream gRPC server
let grpc_server = ProxyServer::new(tx);
let grpc_server = ProxyServer::new(Arc::clone(&cookie_key));

let server_clone = grpc_server.clone();

Expand Down Expand Up @@ -256,15 +286,10 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
}
});

// Wait for core to connect to gRPC and send private cookies encryption key.
let Some(key) = rx.recv().await else {
return Err(anyhow::Error::msg("http channel closed"));
};

// build application
debug!("Setting up API server");
let shared_state = AppState {
key,
cookie_key,
grpc_server,
remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
url: config.url.clone(),
Expand Down Expand Up @@ -324,6 +349,10 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
.route("/info", get(app_info)),
)
.fallback_service(get(handle_404))
.layer(middleware::from_fn_with_state(
shared_state.clone(),
ensure_configured,
))
.layer(middleware::map_response(powered_by_header))
.layer(middleware::from_fn_with_state(
shared_state.clone(),
Expand Down