Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/identity/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ pub trait IdentityBackend: Send + Sync {
&self,
state: &ServiceState,
user_id: &'a str,
) -> Result<Option<String>, IdentityProviderError>;
) -> Result<String, IdentityProviderError>;

/// Find federated user by IDP and Unique ID.
async fn find_federated_user<'a>(
Expand Down
2 changes: 1 addition & 1 deletion src/identity/backend/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ impl IdentityBackend for SqlBackend {
&self,
state: &ServiceState,
user_id: &'a str,
) -> Result<Option<String>, IdentityProviderError> {
) -> Result<String, IdentityProviderError> {
Ok(user::get_user_domain_id(&state.db, user_id).await?)
}

Expand Down
17 changes: 7 additions & 10 deletions src/identity/backend/sql/user/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,17 @@ pub async fn get(
pub async fn get_user_domain_id<U: AsRef<str>>(
db: &DatabaseConnection,
user_id: U,
) -> Result<Option<String>, IdentityDatabaseError> {
Ok(DbUser::find_by_id(user_id.as_ref())
) -> Result<String, IdentityDatabaseError> {
DbUser::find_by_id(user_id.as_ref())
.select_only()
.column(db_user::Column::DomainId)
.into_tuple()
.one(db)
.await
.context("fetching domain_id of a user by ID")?)
.context("fetching domain_id of a user by ID")?
.ok_or(IdentityDatabaseError::UserNotFound(
user_id.as_ref().to_string(),
))
}

#[cfg(test)]
Expand Down Expand Up @@ -216,13 +219,7 @@ mod tests {
]])
.into_connection();

assert_eq!(
get_user_domain_id(&db, "uid")
.await
.unwrap()
.expect("found"),
"did"
);
assert_eq!(get_user_domain_id(&db, "uid").await.unwrap(), "did");
assert_eq!(
db.into_transaction_log(),
[Transaction::from_sql_and_values(
Expand Down
2 changes: 1 addition & 1 deletion src/identity/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ mock! {
&self,
state: &ServiceState,
user_id: &'a str,
) -> Result<Option<String>, IdentityProviderError>;
) -> Result<String, IdentityProviderError>;

async fn find_federated_user<'a>(
&self,
Expand Down
39 changes: 12 additions & 27 deletions src/identity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,21 +267,19 @@ impl IdentityApi for IdentityProvider {
&self,
state: &ServiceState,
user_id: &'a str,
) -> Result<Option<String>, IdentityProviderError> {
) -> Result<String, IdentityProviderError> {
if self.caching {
if let Some(domain_id) = self.user_id_domain_id_cache.read().await.get(user_id) {
return Ok(Some(domain_id.clone()));
return Ok(domain_id.clone());
} else {
let domain_id = self
.backend_driver
.get_user_domain_id(state, user_id)
.await?;
if let Some(did) = &domain_id {
self.user_id_domain_id_cache
.write()
.await
.insert(user_id.to_string(), did.clone());
}
self.user_id_domain_id_cache
.write()
.await
.insert(user_id.to_string(), domain_id.clone());
return Ok(domain_id);
}
} else {
Expand Down Expand Up @@ -498,45 +496,32 @@ mod tests {
.expect_get_user_domain_id()
.withf(|_, uid: &'_ str| uid == "uid")
.times(2) // only 2 times
.returning(|_, _| Ok(Some("did".into())));
.returning(|_, _| Ok("did".into()));
backend
.expect_get_user_domain_id()
.withf(|_, uid: &'_ str| uid == "missing")
.returning(|_, _| Ok(None));
.returning(|_, _| Err(IdentityProviderError::UserNotFound("missing".into())));
let mut provider = IdentityProvider::from_driver(backend);
provider.caching = true;

assert_eq!(
provider
.get_user_domain_id(&state, "uid")
.await
.unwrap()
.expect("domain_id should be there"),
provider.get_user_domain_id(&state, "uid").await.unwrap(),
"did"
);
assert_eq!(
provider
.get_user_domain_id(&state, "uid")
.await
.unwrap()
.expect("domain_id should be there"),
provider.get_user_domain_id(&state, "uid").await.unwrap(),
"did",
"second time data extracted from cache"
);
assert!(
provider
.get_user_domain_id(&state, "missing")
.await
.unwrap()
.is_none()
.is_err()
);
provider.caching = false;
assert_eq!(
provider
.get_user_domain_id(&state, "uid")
.await
.unwrap()
.expect("domain_id should be there"),
provider.get_user_domain_id(&state, "uid").await.unwrap(),
"did",
"third time backend is again triggered causing total of 2 invocations"
);
Expand Down
2 changes: 1 addition & 1 deletion src/identity/types/provider_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub trait IdentityApi: Send + Sync {
&self,
state: &ServiceState,
user_id: &'a str,
) -> Result<Option<String>, IdentityProviderError>;
) -> Result<String, IdentityProviderError>;

async fn find_federated_user<'a>(
&self,
Expand Down
14 changes: 14 additions & 0 deletions src/resource/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ use crate::resource::types::*;
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait ResourceBackend: Send + Sync {
/// Get `enabled` field of the domain.
async fn get_domain_enabled<'a>(
&self,
state: &ServiceState,
domain_id: &'a str,
) -> Result<bool, ResourceProviderError>;

/// Create new project.
async fn create_project(
&self,
Expand Down Expand Up @@ -68,6 +75,13 @@ pub trait ResourceBackend: Send + Sync {
project_id: &'a str,
) -> Result<Option<Vec<Project>>, ResourceProviderError>;

/// List domains.
async fn list_domains(
&self,
state: &ServiceState,
params: &DomainListParameters,
) -> Result<Vec<Domain>, ResourceProviderError>;

/// List projects.
async fn list_projects(
&self,
Expand Down
1 change: 1 addition & 0 deletions src/resource/backend/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub enum ResourceDatabaseError {
source: DatabaseError,
},

/// Domain not found.
#[error("{0}")]
DomainNotFound(String),

Expand Down
152 changes: 27 additions & 125 deletions src/resource/backend/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,18 @@ pub struct SqlBackend {}

#[async_trait]
impl ResourceBackend for SqlBackend {
/// Get `enabled` property of a domain.
#[tracing::instrument(level = "debug", skip(self, state))]
async fn get_domain_enabled<'a>(
&self,
state: &ServiceState,
domain_id: &'a str,
) -> Result<bool, ResourceProviderError> {
Ok(domain::get_domain_enabled(&state.db, domain_id).await?)
}

/// Create new project.
#[tracing::instrument(level = "info", skip(self, state))]
#[tracing::instrument(level = "debug", skip(self, state))]
async fn create_project(
&self,
state: &ServiceState,
Expand All @@ -38,6 +48,7 @@ impl ResourceBackend for SqlBackend {
}

/// Get single domain by ID
#[tracing::instrument(level = "debug", skip(self, state))]
async fn get_domain<'a>(
&self,
state: &ServiceState,
Expand All @@ -47,6 +58,7 @@ impl ResourceBackend for SqlBackend {
}

/// Get single domain by Name
#[tracing::instrument(level = "debug", skip(self, state))]
async fn get_domain_by_name<'a>(
&self,
state: &ServiceState,
Expand All @@ -56,6 +68,7 @@ impl ResourceBackend for SqlBackend {
}

/// Get single project by ID
#[tracing::instrument(level = "debug", skip(self, state))]
async fn get_project<'a>(
&self,
state: &ServiceState,
Expand All @@ -65,6 +78,7 @@ impl ResourceBackend for SqlBackend {
}

/// Get single project by Name and Domain ID
#[tracing::instrument(level = "debug", skip(self, state))]
async fn get_project_by_name<'a>(
&self,
state: &ServiceState,
Expand All @@ -75,6 +89,7 @@ impl ResourceBackend for SqlBackend {
}

/// Get project parents
#[tracing::instrument(level = "debug", skip(self, state))]
async fn get_project_parents<'a>(
&self,
state: &ServiceState,
Expand All @@ -83,7 +98,18 @@ impl ResourceBackend for SqlBackend {
Ok(project::get_project_parents(&state.db, project_id).await?)
}

/// List domains.
#[tracing::instrument(level = "info", skip(self, state))]
async fn list_domains(
&self,
state: &ServiceState,
params: &DomainListParameters,
) -> Result<Vec<Domain>, ResourceProviderError> {
Ok(domain::list(&state.db, params).await?)
}

/// List projects.
#[tracing::instrument(level = "info", skip(self, state))]
async fn list_projects(
&self,
state: &ServiceState,
Expand All @@ -92,127 +118,3 @@ impl ResourceBackend for SqlBackend {
Ok(project::list(&state.db, params).await?)
}
}

//#[cfg(test)]
//mod tests {
// #![allow(clippy::derivable_impls)]
// use chrono::Local;
// use sea_orm::{DatabaseBackend, MockDatabase, Transaction};
//
// use crate::db::entity::{local_user, password, user, user_option};
// use crate::identity::Config;
//
// use super::*;
//
// fn get_user_mock(user_id: String) -> user::Model {
// user::Model {
// id: user_id.clone(),
// domain_id: "foo_domain".into(),
// enabled: Some(true),
// ..Default::default()
// }
// }
//
// fn get_local_user_with_password_mock(
// user_id: String,
// cnt_password: usize,
// ) -> Vec<(local_user::Model, password::Model)> {
// let lu = local_user::Model {
// user_id: user_id.clone(),
// domain_id: "foo_domain".into(),
// name: "Apple Cake".to_owned(),
// ..Default::default()
// };
// let mut passwords: Vec<password::Model> = Vec::new();
// for i in 0..cnt_password {
// passwords.push(password::Model {
// id: i as i32,
// local_user_id: 1,
// expires_at: None,
// self_service: false,
// password_hash: None,
// created_at: Local::now().naive_utc(),
// created_at_int: 12345,
// expires_at_int: None,
// });
// }
// passwords
// .into_iter()
// .map(|x| (lu.clone(), x.clone()))
// .collect()
// }
//
// #[tokio::test]
// async fn test_get_user_local() {
// // Create MockDatabase with mock query results
// let db = MockDatabase::new(DatabaseBackend::Postgres)
// .append_query_results([
// // First query result - select user itself
// vec![get_user_mock("1".into())],
// ])
// .append_query_results([
// //// Second query result - user options
// vec![user_option::Model {
// user_id: "1".into(),
// option_id: "1000".into(),
// option_value: Some("true".into()),
// }],
// ])
// .append_query_results([
// // Third query result - local user with passwords
// get_local_user_with_password_mock("1".into(), 1),
// ])
// .into_connection();
// let config = Config::default();
// assert_eq!(
// get_user(&config, &db, "1".into()).await.unwrap().unwrap(),
// User {
// id: "1".into(),
// domain_id: "foo_domain".into(),
// name: "Apple Cake".to_owned(),
// enabled: true,
// options: UserOptions {
// ignore_change_password_upon_first_use: Some(true),
// ..Default::default()
// },
// ..Default::default()
// }
// );
//
// // Checking transaction log
// assert_eq!(
// db.into_transaction_log(),
// [
// Transaction::from_sql_and_values(
// DatabaseBackend::Postgres,
// r#"SELECT "user"."id", "user"."extra", "user"."enabled",
// "user"."default_project_id", "user"."created_at", "user"."last_active_at",
// "user"."domain_id" FROM "user" WHERE "user"."id" = $1 LIMIT $2"#,
// ["1".into(), 1u64.into()] ),
// Transaction::from_sql_and_values(
// DatabaseBackend::Postgres,
// r#"SELECT "user_option"."user_id",
// "user_option"."option_id", "user_option"."option_value" FROM "user_option"
// INNER JOIN "user" ON "user"."id" = "user_option"."user_id" WHERE "user"."id"
// = $1"#, ["1".into()]
// ),
// Transaction::from_sql_and_values(
// DatabaseBackend::Postgres,
// r#"SELECT "local_user"."id" AS "A_id",
// "local_user"."user_id" AS "A_user_id", "local_user"."domain_id" AS
// "A_domain_id", "local_user"."name" AS "A_name",
// "local_user"."failed_auth_count" AS "A_failed_auth_count",
// "local_user"."failed_auth_at" AS "A_failed_auth_at", "password"."id" AS
// "B_id", "password"."local_user_id" AS "B_local_user_id",
// "password"."self_service" AS "B_self_service", "password"."created_at" AS
// "B_created_at", "password"."expires_at" AS "B_expires_at",
// "password"."password_hash" AS "B_password_hash", "password"."created_at_int"
// AS "B_created_at_int", "password"."expires_at_int" AS "B_expires_at_int" FROM
// "local_user" LEFT JOIN "password" ON "local_user"."id" =
// "password"."local_user_id" WHERE "local_user"."user_id" = $1 ORDER BY
// "local_user"."id" ASC"#, ["1".into()]
// ),
// ]
// );
// }
//}
Loading
Loading