From 784b7015976da20516dd9c90d22ca65413c0da71 Mon Sep 17 00:00:00 2001 From: Artem Goncharov Date: Sat, 24 Jan 2026 14:50:17 +0100 Subject: [PATCH] feat: Validate domain is enabled for token When user domain is disabled the token must be invalid. For the trust trustor and trustee must be validated. --- src/identity/backend.rs | 2 +- src/identity/backend/sql.rs | 2 +- src/identity/backend/sql/user/get.rs | 17 ++- src/identity/mock.rs | 2 +- src/identity/mod.rs | 39 ++---- src/identity/types/provider_api.rs | 2 +- src/resource/backend.rs | 14 +++ src/resource/backend/error.rs | 1 + src/resource/backend/sql.rs | 152 ++++------------------- src/resource/backend/sql/domain/get.rs | 143 +++++++++++++++++++-- src/resource/backend/sql/domain/list.rs | 130 +++++++++++++++++++ src/resource/backend/sql/domain/mod.rs | 29 ++++- src/resource/backend/sql/project/list.rs | 2 +- src/resource/error.rs | 1 + src/resource/mock.rs | 12 ++ src/resource/mod.rs | 21 ++++ src/resource/types/domain.rs | 35 +++++- src/resource/types/project.rs | 2 +- src/resource/types/provider_api.rs | 36 ++++-- src/token/error.rs | 12 ++ src/token/types.rs | 56 +++++++-- 21 files changed, 508 insertions(+), 202 deletions(-) create mode 100644 src/resource/backend/sql/domain/list.rs diff --git a/src/identity/backend.rs b/src/identity/backend.rs index 2b53707a..cc366f6e 100644 --- a/src/identity/backend.rs +++ b/src/identity/backend.rs @@ -113,7 +113,7 @@ pub trait IdentityBackend: Send + Sync { &self, state: &ServiceState, user_id: &'a str, - ) -> Result, IdentityProviderError>; + ) -> Result; /// Find federated user by IDP and Unique ID. async fn find_federated_user<'a>( diff --git a/src/identity/backend/sql.rs b/src/identity/backend/sql.rs index aa353c70..0dd69989 100644 --- a/src/identity/backend/sql.rs +++ b/src/identity/backend/sql.rs @@ -169,7 +169,7 @@ impl IdentityBackend for SqlBackend { &self, state: &ServiceState, user_id: &'a str, - ) -> Result, IdentityProviderError> { + ) -> Result { Ok(user::get_user_domain_id(&state.db, user_id).await?) } diff --git a/src/identity/backend/sql/user/get.rs b/src/identity/backend/sql/user/get.rs index e2696d68..7b18a5b6 100644 --- a/src/identity/backend/sql/user/get.rs +++ b/src/identity/backend/sql/user/get.rs @@ -104,14 +104,17 @@ pub async fn get( pub async fn get_user_domain_id>( db: &DatabaseConnection, user_id: U, -) -> Result, IdentityDatabaseError> { - Ok(DbUser::find_by_id(user_id.as_ref()) +) -> Result { + DbUser::find_by_id(user_id.as_ref()) .select_only() .column(db_user::Column::DomainId) .into_tuple() .one(db) .await - .context("fetching domain_id of a user by ID")?) + .context("fetching domain_id of a user by ID")? + .ok_or(IdentityDatabaseError::UserNotFound( + user_id.as_ref().to_string(), + )) } #[cfg(test)] @@ -216,13 +219,7 @@ mod tests { ]]) .into_connection(); - assert_eq!( - get_user_domain_id(&db, "uid") - .await - .unwrap() - .expect("found"), - "did" - ); + assert_eq!(get_user_domain_id(&db, "uid").await.unwrap(), "did"); assert_eq!( db.into_transaction_log(), [Transaction::from_sql_and_values( diff --git a/src/identity/mock.rs b/src/identity/mock.rs index 7ebdf47f..5dc043a4 100644 --- a/src/identity/mock.rs +++ b/src/identity/mock.rs @@ -109,7 +109,7 @@ mock! { &self, state: &ServiceState, user_id: &'a str, - ) -> Result, IdentityProviderError>; + ) -> Result; async fn find_federated_user<'a>( &self, diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 679dc324..d2195f12 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -267,21 +267,19 @@ impl IdentityApi for IdentityProvider { &self, state: &ServiceState, user_id: &'a str, - ) -> Result, IdentityProviderError> { + ) -> Result { if self.caching { if let Some(domain_id) = self.user_id_domain_id_cache.read().await.get(user_id) { - return Ok(Some(domain_id.clone())); + return Ok(domain_id.clone()); } else { let domain_id = self .backend_driver .get_user_domain_id(state, user_id) .await?; - if let Some(did) = &domain_id { - self.user_id_domain_id_cache - .write() - .await - .insert(user_id.to_string(), did.clone()); - } + self.user_id_domain_id_cache + .write() + .await + .insert(user_id.to_string(), domain_id.clone()); return Ok(domain_id); } } else { @@ -498,28 +496,20 @@ mod tests { .expect_get_user_domain_id() .withf(|_, uid: &'_ str| uid == "uid") .times(2) // only 2 times - .returning(|_, _| Ok(Some("did".into()))); + .returning(|_, _| Ok("did".into())); backend .expect_get_user_domain_id() .withf(|_, uid: &'_ str| uid == "missing") - .returning(|_, _| Ok(None)); + .returning(|_, _| Err(IdentityProviderError::UserNotFound("missing".into()))); let mut provider = IdentityProvider::from_driver(backend); provider.caching = true; assert_eq!( - provider - .get_user_domain_id(&state, "uid") - .await - .unwrap() - .expect("domain_id should be there"), + provider.get_user_domain_id(&state, "uid").await.unwrap(), "did" ); assert_eq!( - provider - .get_user_domain_id(&state, "uid") - .await - .unwrap() - .expect("domain_id should be there"), + provider.get_user_domain_id(&state, "uid").await.unwrap(), "did", "second time data extracted from cache" ); @@ -527,16 +517,11 @@ mod tests { provider .get_user_domain_id(&state, "missing") .await - .unwrap() - .is_none() + .is_err() ); provider.caching = false; assert_eq!( - provider - .get_user_domain_id(&state, "uid") - .await - .unwrap() - .expect("domain_id should be there"), + provider.get_user_domain_id(&state, "uid").await.unwrap(), "did", "third time backend is again triggered causing total of 2 invocations" ); diff --git a/src/identity/types/provider_api.rs b/src/identity/types/provider_api.rs index 68b44729..1648578b 100644 --- a/src/identity/types/provider_api.rs +++ b/src/identity/types/provider_api.rs @@ -96,7 +96,7 @@ pub trait IdentityApi: Send + Sync { &self, state: &ServiceState, user_id: &'a str, - ) -> Result, IdentityProviderError>; + ) -> Result; async fn find_federated_user<'a>( &self, diff --git a/src/resource/backend.rs b/src/resource/backend.rs index 4d841b6f..e6f17a4c 100644 --- a/src/resource/backend.rs +++ b/src/resource/backend.rs @@ -25,6 +25,13 @@ use crate::resource::types::*; #[cfg_attr(test, mockall::automock)] #[async_trait] pub trait ResourceBackend: Send + Sync { + /// Get `enabled` field of the domain. + async fn get_domain_enabled<'a>( + &self, + state: &ServiceState, + domain_id: &'a str, + ) -> Result; + /// Create new project. async fn create_project( &self, @@ -68,6 +75,13 @@ pub trait ResourceBackend: Send + Sync { project_id: &'a str, ) -> Result>, ResourceProviderError>; + /// List domains. + async fn list_domains( + &self, + state: &ServiceState, + params: &DomainListParameters, + ) -> Result, ResourceProviderError>; + /// List projects. async fn list_projects( &self, diff --git a/src/resource/backend/error.rs b/src/resource/backend/error.rs index f82d5899..6f9adadd 100644 --- a/src/resource/backend/error.rs +++ b/src/resource/backend/error.rs @@ -25,6 +25,7 @@ pub enum ResourceDatabaseError { source: DatabaseError, }, + /// Domain not found. #[error("{0}")] DomainNotFound(String), diff --git a/src/resource/backend/sql.rs b/src/resource/backend/sql.rs index 41d6b19e..5ff10008 100644 --- a/src/resource/backend/sql.rs +++ b/src/resource/backend/sql.rs @@ -27,8 +27,18 @@ pub struct SqlBackend {} #[async_trait] impl ResourceBackend for SqlBackend { + /// Get `enabled` property of a domain. + #[tracing::instrument(level = "debug", skip(self, state))] + async fn get_domain_enabled<'a>( + &self, + state: &ServiceState, + domain_id: &'a str, + ) -> Result { + Ok(domain::get_domain_enabled(&state.db, domain_id).await?) + } + /// Create new project. - #[tracing::instrument(level = "info", skip(self, state))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_project( &self, state: &ServiceState, @@ -38,6 +48,7 @@ impl ResourceBackend for SqlBackend { } /// Get single domain by ID + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_domain<'a>( &self, state: &ServiceState, @@ -47,6 +58,7 @@ impl ResourceBackend for SqlBackend { } /// Get single domain by Name + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_domain_by_name<'a>( &self, state: &ServiceState, @@ -56,6 +68,7 @@ impl ResourceBackend for SqlBackend { } /// Get single project by ID + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_project<'a>( &self, state: &ServiceState, @@ -65,6 +78,7 @@ impl ResourceBackend for SqlBackend { } /// Get single project by Name and Domain ID + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_project_by_name<'a>( &self, state: &ServiceState, @@ -75,6 +89,7 @@ impl ResourceBackend for SqlBackend { } /// Get project parents + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_project_parents<'a>( &self, state: &ServiceState, @@ -83,7 +98,18 @@ impl ResourceBackend for SqlBackend { Ok(project::get_project_parents(&state.db, project_id).await?) } + /// List domains. + #[tracing::instrument(level = "info", skip(self, state))] + async fn list_domains( + &self, + state: &ServiceState, + params: &DomainListParameters, + ) -> Result, ResourceProviderError> { + Ok(domain::list(&state.db, params).await?) + } + /// List projects. + #[tracing::instrument(level = "info", skip(self, state))] async fn list_projects( &self, state: &ServiceState, @@ -92,127 +118,3 @@ impl ResourceBackend for SqlBackend { Ok(project::list(&state.db, params).await?) } } - -//#[cfg(test)] -//mod tests { -// #![allow(clippy::derivable_impls)] -// use chrono::Local; -// use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; -// -// use crate::db::entity::{local_user, password, user, user_option}; -// use crate::identity::Config; -// -// use super::*; -// -// fn get_user_mock(user_id: String) -> user::Model { -// user::Model { -// id: user_id.clone(), -// domain_id: "foo_domain".into(), -// enabled: Some(true), -// ..Default::default() -// } -// } -// -// fn get_local_user_with_password_mock( -// user_id: String, -// cnt_password: usize, -// ) -> Vec<(local_user::Model, password::Model)> { -// let lu = local_user::Model { -// user_id: user_id.clone(), -// domain_id: "foo_domain".into(), -// name: "Apple Cake".to_owned(), -// ..Default::default() -// }; -// let mut passwords: Vec = Vec::new(); -// for i in 0..cnt_password { -// passwords.push(password::Model { -// id: i as i32, -// local_user_id: 1, -// expires_at: None, -// self_service: false, -// password_hash: None, -// created_at: Local::now().naive_utc(), -// created_at_int: 12345, -// expires_at_int: None, -// }); -// } -// passwords -// .into_iter() -// .map(|x| (lu.clone(), x.clone())) -// .collect() -// } -// -// #[tokio::test] -// async fn test_get_user_local() { -// // Create MockDatabase with mock query results -// let db = MockDatabase::new(DatabaseBackend::Postgres) -// .append_query_results([ -// // First query result - select user itself -// vec![get_user_mock("1".into())], -// ]) -// .append_query_results([ -// //// Second query result - user options -// vec![user_option::Model { -// user_id: "1".into(), -// option_id: "1000".into(), -// option_value: Some("true".into()), -// }], -// ]) -// .append_query_results([ -// // Third query result - local user with passwords -// get_local_user_with_password_mock("1".into(), 1), -// ]) -// .into_connection(); -// let config = Config::default(); -// assert_eq!( -// get_user(&config, &db, "1".into()).await.unwrap().unwrap(), -// User { -// id: "1".into(), -// domain_id: "foo_domain".into(), -// name: "Apple Cake".to_owned(), -// enabled: true, -// options: UserOptions { -// ignore_change_password_upon_first_use: Some(true), -// ..Default::default() -// }, -// ..Default::default() -// } -// ); -// -// // Checking transaction log -// assert_eq!( -// db.into_transaction_log(), -// [ -// Transaction::from_sql_and_values( -// DatabaseBackend::Postgres, -// r#"SELECT "user"."id", "user"."extra", "user"."enabled", -// "user"."default_project_id", "user"."created_at", "user"."last_active_at", -// "user"."domain_id" FROM "user" WHERE "user"."id" = $1 LIMIT $2"#, -// ["1".into(), 1u64.into()] ), -// Transaction::from_sql_and_values( -// DatabaseBackend::Postgres, -// r#"SELECT "user_option"."user_id", -// "user_option"."option_id", "user_option"."option_value" FROM "user_option" -// INNER JOIN "user" ON "user"."id" = "user_option"."user_id" WHERE "user"."id" -// = $1"#, ["1".into()] -// ), -// Transaction::from_sql_and_values( -// DatabaseBackend::Postgres, -// r#"SELECT "local_user"."id" AS "A_id", -// "local_user"."user_id" AS "A_user_id", "local_user"."domain_id" AS -// "A_domain_id", "local_user"."name" AS "A_name", -// "local_user"."failed_auth_count" AS "A_failed_auth_count", -// "local_user"."failed_auth_at" AS "A_failed_auth_at", "password"."id" AS -// "B_id", "password"."local_user_id" AS "B_local_user_id", -// "password"."self_service" AS "B_self_service", "password"."created_at" AS -// "B_created_at", "password"."expires_at" AS "B_expires_at", -// "password"."password_hash" AS "B_password_hash", "password"."created_at_int" -// AS "B_created_at_int", "password"."expires_at_int" AS "B_expires_at_int" FROM -// "local_user" LEFT JOIN "password" ON "local_user"."id" = -// "password"."local_user_id" WHERE "local_user"."user_id" = $1 ORDER BY -// "local_user"."id" ASC"#, ["1".into()] -// ), -// ] -// ); -// } -//} diff --git a/src/resource/backend/sql/domain/get.rs b/src/resource/backend/sql/domain/get.rs index 05cdf490..41446cc6 100644 --- a/src/resource/backend/sql/domain/get.rs +++ b/src/resource/backend/sql/domain/get.rs @@ -21,18 +21,35 @@ use crate::error::DbContextExt; use crate::resource::backend::error::ResourceDatabaseError; use crate::resource::types::Domain; +pub async fn get_domain_enabled>( + db: &DatabaseConnection, + domain_id: I, +) -> Result { + DbProject::find_by_id(domain_id.as_ref()) + .filter(db_project::Column::IsDomain.eq(true)) + .select_only() + .column(db_project::Column::Enabled) + .into_tuple() + .one(db) + .await + .context("fetching domain `enabled` by id")? + .map(|x: Option| x.unwrap_or(true)) // python keystone defaults to `true` when unset + .ok_or(ResourceDatabaseError::DomainNotFound( + domain_id.as_ref().to_string(), + )) +} + pub async fn get_domain_by_id>( db: &DatabaseConnection, domain_id: I, ) -> Result, ResourceDatabaseError> { - let domain_select = - DbProject::find_by_id(domain_id.as_ref()).filter(db_project::Column::IsDomain.eq(true)); - - let domain_entry: Option = domain_select + DbProject::find_by_id(domain_id.as_ref()) + .filter(db_project::Column::IsDomain.eq(true)) .one(db) .await - .context("fetching domain by id")?; - domain_entry.map(TryInto::try_into).transpose() + .context("fetching domain by id")? + .map(TryInto::try_into) + .transpose() } pub async fn get_domain_by_name>( @@ -43,9 +60,117 @@ pub async fn get_domain_by_name>( .filter(db_project::Column::IsDomain.eq(true)) .filter(db_project::Column::Name.eq(domain_name.as_ref())); - let domain_entry: Option = domain_select + domain_select .one(db) .await - .context("fetching domain by name")?; - domain_entry.map(TryInto::try_into).transpose() + .context("fetching domain by name")? + .map(TryInto::try_into) + .transpose() +} + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, IntoMockRow, MockDatabase, Transaction}; + use std::collections::BTreeMap; + + use super::super::tests::*; + use super::*; + + #[tokio::test] + async fn test_get_by_name() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![get_domain_mock("1")]]) + .into_connection(); + + assert_eq!( + get_domain_by_name(&db, "name") + .await + .unwrap() + .expect("entry found"), + get_domain_mock("1").try_into().unwrap() + ); + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "project"."id", "project"."name", "project"."extra", "project"."description", "project"."enabled", "project"."domain_id", "project"."parent_id", "project"."is_domain" FROM "project" WHERE "project"."is_domain" = $1 AND "project"."name" = $2 LIMIT $3"#, + [true.into(), "name".into(), 1u64.into()] + ),] + ); + } + + #[tokio::test] + async fn test_get_by_id() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![get_domain_mock("1")]]) + .into_connection(); + + assert_eq!( + get_domain_by_id(&db, "1") + .await + .unwrap() + .expect("entry found"), + get_domain_mock("1").try_into().unwrap() + ); + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "project"."id", "project"."name", "project"."extra", "project"."description", "project"."enabled", "project"."domain_id", "project"."parent_id", "project"."is_domain" FROM "project" WHERE "project"."id" = $1 AND "project"."is_domain" = $2 LIMIT $3"#, + ["1".into(), true.into(), 1u64.into()] + ),] + ); + } + + #[tokio::test] + async fn test_get_domain_enabled() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([Vec::::new()]) + .append_query_results([vec![ + BTreeMap::from([("enabled", Into::::into(Some(true)))]).into_mock_row(), + ]]) + .append_query_results([vec![ + BTreeMap::from([("enabled", Into::::into(Some(false)))]).into_mock_row(), + ]]) + .append_query_results([vec![ + BTreeMap::from([("enabled", Into::::into(None::))]).into_mock_row(), + ]]) + .into_connection(); + + assert!(get_domain_enabled(&db, "missing").await.is_err()); + assert!(get_domain_enabled(&db, "id").await.unwrap(),); + assert!( + !get_domain_enabled(&db, "id").await.unwrap(), + "Some(false) should be disabled" + ); + assert!( + get_domain_enabled(&db, "id").await.unwrap(), + "enabled is empty in the db considered as active" + ); + assert_eq!( + db.into_transaction_log(), + [ + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "project"."enabled" FROM "project" WHERE "project"."id" = $1 AND "project"."is_domain" = $2 LIMIT $3"#, + ["missing".into(), true.into(), 1u64.into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "project"."enabled" FROM "project" WHERE "project"."id" = $1 AND "project"."is_domain" = $2 LIMIT $3"#, + ["id".into(), true.into(), 1u64.into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "project"."enabled" FROM "project" WHERE "project"."id" = $1 AND "project"."is_domain" = $2 LIMIT $3"#, + ["id".into(), true.into(), 1u64.into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "project"."enabled" FROM "project" WHERE "project"."id" = $1 AND "project"."is_domain" = $2 LIMIT $3"#, + ["id".into(), true.into(), 1u64.into()] + ), + ] + ); + } } diff --git a/src/resource/backend/sql/domain/list.rs b/src/resource/backend/sql/domain/list.rs new file mode 100644 index 00000000..62c9a833 --- /dev/null +++ b/src/resource/backend/sql/domain/list.rs @@ -0,0 +1,130 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; +use sea_orm::query::*; +use sea_orm::{Cursor, SelectModel}; + +use crate::db::entity::{prelude::Project as DbProject, project as db_project}; +use crate::error::DbContextExt; +use crate::resource::backend::error::ResourceDatabaseError; +use crate::resource::types::*; + +/// Prepare the paginated query for listing domains. +fn get_list_query( + params: &DomainListParameters, +) -> Result>, ResourceDatabaseError> { + let mut select = DbProject::find().filter(db_project::Column::IsDomain.eq(true)); + + if let Some(val) = ¶ms.name { + select = select.filter(db_project::Column::Name.eq(val)); + } + + if let Some(val) = ¶ms.ids + && !val.is_empty() + { + select = select.filter(db_project::Column::Id.is_in(val)); + } + + Ok(select.cursor_by(db_project::Column::Id)) +} + +pub async fn list( + db: &DatabaseConnection, + params: &DomainListParameters, +) -> Result, ResourceDatabaseError> { + get_list_query(params)? + .all(db) + .await + .context("listing domains")? + .into_iter() + .map(TryInto::try_into) + .collect() +} + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, QueryOrder, Transaction, sea_query::*}; + + use super::super::tests::*; + use super::*; + + #[tokio::test] + async fn test_query_all() { + assert_eq!( + r#"SELECT "project"."id", "project"."name", "project"."extra", "project"."description", "project"."enabled", "project"."domain_id", "project"."parent_id", "project"."is_domain" FROM "project" WHERE "project"."is_domain" = TRUE"#, + QueryOrder::query(&mut get_list_query(&DomainListParameters::default()).unwrap()) + .to_string(PostgresQueryBuilder) + ); + } + + #[tokio::test] + async fn test_query_name() { + assert!( + QueryOrder::query( + &mut get_list_query(&DomainListParameters { + name: Some("name".into()), + ..Default::default() + }) + .unwrap() + ) + .to_string(PostgresQueryBuilder) + .contains("\"project\".\"name\" = 'name'") + ); + } + + #[tokio::test] + async fn test_query_ids() { + let q = QueryOrder::query( + &mut get_list_query(&DomainListParameters { + ids: Some(std::collections::HashSet::from([ + "1".to_string(), + "2".to_string(), + ])), + ..Default::default() + }) + .unwrap(), + ) + .to_string(PostgresQueryBuilder); + assert!(q.contains("\"project\".\"id\" IN ('"), "{}", q); + } + + #[tokio::test] + async fn test_list() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![get_domain_mock("pid1")]]) + .into_connection(); + + assert_eq!( + list(&db, &DomainListParameters::default()).await.unwrap(), + vec![Domain { + description: None, + enabled: true, + extra: None, + id: "pid1".into(), + name: "name".into(), + }] + ); + + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "project"."id", "project"."name", "project"."extra", "project"."description", "project"."enabled", "project"."domain_id", "project"."parent_id", "project"."is_domain" FROM "project" WHERE "project"."is_domain" = $1 ORDER BY "project"."id" ASC"#, + [true.into()] + ),] + ); + } +} diff --git a/src/resource/backend/sql/domain/mod.rs b/src/resource/backend/sql/domain/mod.rs index f62140a2..5c2a588f 100644 --- a/src/resource/backend/sql/domain/mod.rs +++ b/src/resource/backend/sql/domain/mod.rs @@ -15,9 +15,10 @@ use serde_json::Value; use tracing::error; mod get; +mod list; -pub use get::get_domain_by_id; -pub use get::get_domain_by_name; +pub use get::{get_domain_by_id, get_domain_by_name, get_domain_enabled}; +pub use list::list; use crate::db::entity::project as db_project; use crate::resource::backend::error::ResourceDatabaseError; @@ -34,8 +35,10 @@ impl TryFrom for Domain { if let Some(description) = &value.description { domain_builder.description(description.clone()); } - domain_builder.enabled(value.enabled.unwrap_or(false)); - if let Some(extra) = &value.extra { + domain_builder.enabled(value.enabled.unwrap_or(true)); + if let Some(extra) = &value.extra + && "{}" != extra + { domain_builder.extra( serde_json::from_str::(extra) .inspect_err(|e| error!("failed to deserialize domain extra: {e}")) @@ -46,3 +49,21 @@ impl TryFrom for Domain { Ok(domain_builder.build()?) } } + +#[cfg(test)] +pub mod tests { + use super::*; + + pub fn get_domain_mock>(id: S) -> db_project::Model { + db_project::Model { + description: None, + domain_id: "did".into(), + enabled: Some(true), + extra: Some("{}".to_string()), + id: id.into(), + is_domain: true, + name: "name".into(), + parent_id: None, + } + } +} diff --git a/src/resource/backend/sql/project/list.rs b/src/resource/backend/sql/project/list.rs index 62a74bba..94d055bf 100644 --- a/src/resource/backend/sql/project/list.rs +++ b/src/resource/backend/sql/project/list.rs @@ -22,7 +22,7 @@ use crate::error::DbContextExt; use crate::resource::backend::error::ResourceDatabaseError; use crate::resource::types::*; -/// Prepare the paginated query for listing mappings. +/// Prepare the paginated query for listing projects. fn get_list_query( params: &ProjectListParameters, ) -> Result>, ResourceDatabaseError> { diff --git a/src/resource/error.rs b/src/resource/error.rs index 66826bb7..99161d81 100644 --- a/src/resource/error.rs +++ b/src/resource/error.rs @@ -29,6 +29,7 @@ pub enum ResourceProviderError { #[error("conflict: {0}")] Conflict(String), + /// Domain not found. #[error("domain {0} not found")] DomainNotFound(String), diff --git a/src/resource/mock.rs b/src/resource/mock.rs index e317f786..5346ccef 100644 --- a/src/resource/mock.rs +++ b/src/resource/mock.rs @@ -29,6 +29,12 @@ mock! { #[async_trait] impl ResourceApi for ResourceProvider { + async fn get_domain_enabled<'a>( + &self, + state: &ServiceState, + domain_id: &'a str, + ) -> Result; + async fn create_project( &self, state: &ServiceState, @@ -66,6 +72,12 @@ mock! { project_id: &'a str, ) -> Result>, ResourceProviderError>; + async fn list_domains( + &self, + state: &ServiceState, + params: &DomainListParameters, + ) -> Result, ResourceProviderError>; + async fn list_projects( &self, state: &ServiceState, diff --git a/src/resource/mod.rs b/src/resource/mod.rs index ce66b6ab..b054678b 100644 --- a/src/resource/mod.rs +++ b/src/resource/mod.rs @@ -81,6 +81,17 @@ impl ResourceProvider { #[async_trait] impl ResourceApi for ResourceProvider { + /// Check whether the domain is enabled. + async fn get_domain_enabled<'a>( + &self, + state: &ServiceState, + domain_id: &'a str, + ) -> Result { + self.backend_driver + .get_domain_enabled(state, domain_id) + .await + } + /// Create new project. #[tracing::instrument(level = "info", skip(self, state))] async fn create_project( @@ -154,6 +165,16 @@ impl ResourceApi for ResourceProvider { .await } + /// List domains. + #[tracing::instrument(level = "info", skip(self, state))] + async fn list_domains( + &self, + state: &ServiceState, + params: &DomainListParameters, + ) -> Result, ResourceProviderError> { + self.backend_driver.list_domains(state, params).await + } + /// List projects. #[tracing::instrument(level = "info", skip(self, state))] async fn list_projects( diff --git a/src/resource/types/domain.rs b/src/resource/types/domain.rs index 50354fc9..9af9cc00 100644 --- a/src/resource/types/domain.rs +++ b/src/resource/types/domain.rs @@ -15,22 +15,47 @@ use derive_builder::Builder; use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::collections::HashSet; +use validator::Validate; use crate::error::BuilderError; -#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize, Validate)] #[builder(build_fn(error = "BuilderError"))] #[builder(setter(strip_option, into))] pub struct Domain { + /// The resource description. + #[builder(default)] + #[validate(length(min = 1, max = 255))] + pub description: Option, + + /// If set to true, domain is enabled. If set to false, domain is disabled. + pub enabled: bool, + /// The domain ID. + #[validate(length(min = 1, max = 64))] pub id: String, + /// The domain name. + #[validate(length(min = 1, max = 255))] pub name: String, - pub enabled: bool, - /// The resource description. - #[builder(default)] - pub description: Option, + /// Additional domain properties. #[builder(default)] pub extra: Option, } + +/// Domain listing parameters. +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize, Validate)] +#[builder(build_fn(error = "BuilderError"))] +pub struct DomainListParameters { + /// Filter domains by the `id` attribute. Items are treated as `IN[]`. + #[builder(default)] + #[validate(length(min = 1, max = 64))] + pub ids: Option>, + + /// Filter domains by the `name` attribute. + #[builder(default)] + #[validate(length(max = 255))] + pub name: Option, +} diff --git a/src/resource/types/project.rs b/src/resource/types/project.rs index cd3d2ac1..7cc1b151 100644 --- a/src/resource/types/project.rs +++ b/src/resource/types/project.rs @@ -42,7 +42,7 @@ pub struct Project { pub extra: Option, /// The project ID. - #[validate(length(min = 1, max = 255))] + #[validate(length(min = 1, max = 64))] pub id: String, /// Indicates whether the project also acts as a domain. If set to true, diff --git a/src/resource/types/provider_api.rs b/src/resource/types/provider_api.rs index 4977e621..42d68592 100644 --- a/src/resource/types/provider_api.rs +++ b/src/resource/types/provider_api.rs @@ -16,37 +16,41 @@ use async_trait::async_trait; use crate::keystone::ServiceState; use crate::resource::ResourceProviderError; -use crate::resource::types::domain::Domain; +use crate::resource::types::domain::*; use crate::resource::types::project::*; /// Resource API. #[async_trait] pub trait ResourceApi: Send + Sync { - /// Create new project. + /// Check whether the domain is enabled. + async fn get_domain_enabled<'a>( + &self, + state: &ServiceState, + domain_id: &'a str, + ) -> Result; + + /// Create a new project. async fn create_project( &self, state: &ServiceState, project: ProjectCreate, ) -> Result; + /// Get a domain by the `id`. async fn get_domain<'a>( &self, state: &ServiceState, domain_id: &'a str, ) -> Result, ResourceProviderError>; - async fn find_domain_by_name<'a>( - &self, - state: &ServiceState, - domain_name: &'a str, - ) -> Result, ResourceProviderError>; - + /// Get a project by the `id`. async fn get_project<'a>( &self, state: &ServiceState, project_id: &'a str, ) -> Result, ResourceProviderError>; + /// Get a project by the `name` and the `domain_id`. async fn get_project_by_name<'a>( &self, state: &ServiceState, @@ -54,13 +58,27 @@ pub trait ResourceApi: Send + Sync { domain_id: &'a str, ) -> Result, ResourceProviderError>; - /// Get project parents + /// Get project parents. async fn get_project_parents<'a>( &self, state: &ServiceState, project_id: &'a str, ) -> Result>, ResourceProviderError>; + /// Find domain by the `name`. + async fn find_domain_by_name<'a>( + &self, + state: &ServiceState, + domain_name: &'a str, + ) -> Result, ResourceProviderError>; + + /// List domains. + async fn list_domains( + &self, + state: &ServiceState, + params: &DomainListParameters, + ) -> Result, ResourceProviderError>; + /// List projects. async fn list_projects( &self, diff --git a/src/token/error.rs b/src/token/error.rs index 67512086..6e0c4e23 100644 --- a/src/token/error.rs +++ b/src/token/error.rs @@ -205,6 +205,10 @@ pub enum TokenProviderError { #[error(transparent)] TrustProvider(#[from] crate::trust::TrustError), + /// The user domain of the trustee is disabled. + #[error("trustee domain disabled")] + TrustorDomainDisabled, + /// Integer conversion error. #[error("int parse")] TryFromIntError(#[from] TryFromIntError), @@ -217,6 +221,14 @@ pub enum TokenProviderError { #[error("user disabled")] UserDisabled(String), + /// The user domain is disabled. + #[error("user domain disabled")] + UserDomainDisabled, + + /// The user is not trustee. + #[error("the token subject user is not trustee of the trust")] + UserIsNotTrustee, + /// The user cannot be found error. #[error("user cannot be found: {0}")] UserNotFound(String), diff --git a/src/token/types.rs b/src/token/types.rs index 123376c9..01c484d2 100644 --- a/src/token/types.rs +++ b/src/token/types.rs @@ -18,9 +18,9 @@ use serde::Serialize; use validator::Validate; use crate::assignment::types::Role; -use crate::identity::types::UserResponse; +use crate::identity::{IdentityApi, types::UserResponse}; use crate::keystone::ServiceState; -use crate::resource::types::{Domain, Project}; +use crate::resource::{ResourceApi, types::*}; use crate::token::error::TokenProviderError; use crate::trust::TrustApi; @@ -336,7 +336,7 @@ impl Token { Ok(()) } - /// Validate the token issuer. + /// Validate the token subject. /// /// Perform checks for the token subject: /// @@ -344,11 +344,27 @@ impl Token { /// - user domain is enabled /// - application credential is not expired pub async fn validate_subject(&self, state: &ServiceState) -> Result<(), TokenProviderError> { - // The "user" must be active - if !self.user().as_ref().is_some_and(|user| user.enabled) { - return Err(TokenProviderError::UserDisabled(self.user_id().clone())); + let user_domain_id: String; + if let Some(user) = self.user() { + // The "user" must be active + if !user.enabled { + return Err(TokenProviderError::UserDisabled(user.id.clone())); + } + + // Ensure user domain is enabled + if !state + .provider + .get_resource_provider() + .get_domain_enabled(state, &user.domain_id) + .await? + { + return Err(TokenProviderError::UserDomainDisabled); + } + + user_domain_id = user.domain_id.clone(); + } else { + return Err(TokenProviderError::SubjectMissing); } - // TODO: User domain must be enabled match self { Token::ApplicationCredential(data) => { @@ -370,6 +386,7 @@ impl Token { Token::Restricted(_data) => {} Token::SystemScope(_data) => {} Token::Trust(data) => { + // Validate the trust chain state .provider .get_trust_provider() @@ -380,6 +397,31 @@ impl Token { .ok_or(TokenProviderError::SubjectMissing)?, ) .await?; + // Validate trustor and trustee + if let Some(trust) = &data.trust { + if data.user_id != trust.trustee_user_id { + return Err(TokenProviderError::UserIsNotTrustee); + } + + // Resolve and verify trustor domain is enabled + let trustor_domain_id = state + .provider + .get_identity_provider() + .get_user_domain_id(state, &trust.trustor_user_id) + .await?; + + if user_domain_id != trustor_domain_id + && !state + .provider + .get_resource_provider() + .get_domain_enabled(state, &trustor_domain_id) + .await? + { + return Err(TokenProviderError::TrustorDomainDisabled); + } + } else { + return Err(TokenProviderError::SubjectMissing); + } } Token::Unscoped(_data) => {} }