mirror of
https://github.com/lldap/lldap.git
synced 2026-03-31 15:07:48 +01:00
server: extract the sql backend handler to a separate crate
This commit is contained in:
committed by
nitnelave
parent
ee21d83056
commit
55de3ac329
@@ -132,6 +132,12 @@ pub mod server {
|
||||
pub use super::*;
|
||||
pub type ServerRegistration = opaque_ke::ServerRegistration<DefaultSuite>;
|
||||
pub type ServerSetup = opaque_ke::ServerSetup<DefaultSuite>;
|
||||
|
||||
pub fn generate_random_private_key() -> ServerSetup {
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
ServerSetup::new(&mut rng)
|
||||
}
|
||||
|
||||
/// Methods to register a new user, from the server side.
|
||||
pub mod registration {
|
||||
pub use super::*;
|
||||
|
||||
@@ -61,7 +61,6 @@ mod tests {
|
||||
use lldap_domain_handlers::handler::{GroupRequestFilter, UserRequestFilter};
|
||||
use lldap_test_utils::MockTestBackendHandler;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_compare_user() {
|
||||
|
||||
@@ -172,7 +172,6 @@ mod tests {
|
||||
use lldap_test_utils::MockTestBackendHandler;
|
||||
use mockall::predicate::eq;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_user() {
|
||||
|
||||
@@ -115,7 +115,6 @@ mod tests {
|
||||
use lldap_test_utils::MockTestBackendHandler;
|
||||
use mockall::predicate::eq;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_user() {
|
||||
|
||||
@@ -140,7 +140,6 @@ mod tests {
|
||||
use mockall::predicate::eq;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashSet;
|
||||
|
||||
|
||||
fn setup_target_user_groups(
|
||||
mock: &mut MockTestBackendHandler,
|
||||
|
||||
@@ -333,7 +333,6 @@ mod tests {
|
||||
use lldap_test_utils::MockTestBackendHandler;
|
||||
use mockall::predicate::eq;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_root_dse() {
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
[package]
|
||||
name = "lldap_sql_backend_handler"
|
||||
version = "0.1.0"
|
||||
description = "SQL backend for LLDAP"
|
||||
authors.workspace = true
|
||||
edition.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[features]
|
||||
test = []
|
||||
|
||||
[dependencies]
|
||||
anyhow = "*"
|
||||
async-trait = "0.1"
|
||||
base64 = "0.21"
|
||||
bincode = "1.3"
|
||||
itertools = "0.10"
|
||||
ldap3_proto = "0.6.0"
|
||||
orion = "0.17"
|
||||
serde_json = "1"
|
||||
tracing = "*"
|
||||
|
||||
[dependencies.chrono]
|
||||
features = ["serde"]
|
||||
version = "*"
|
||||
|
||||
[dependencies.rand]
|
||||
features = ["small_rng", "getrandom"]
|
||||
version = "0.8"
|
||||
|
||||
[dependencies.sea-orm]
|
||||
workspace = true
|
||||
features = [
|
||||
"macros",
|
||||
"with-chrono",
|
||||
"with-uuid",
|
||||
"sqlx-all",
|
||||
"runtime-actix-rustls",
|
||||
]
|
||||
|
||||
[dependencies.secstr]
|
||||
features = ["serde"]
|
||||
version = "*"
|
||||
|
||||
[dependencies.serde]
|
||||
workspace = true
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = ["v1", "v3"]
|
||||
|
||||
[dependencies.lldap_access_control]
|
||||
path = "../access-control"
|
||||
|
||||
[dependencies.lldap_auth]
|
||||
path = "../auth"
|
||||
features = ["opaque_server", "opaque_client", "sea_orm"]
|
||||
|
||||
[dependencies.lldap_domain]
|
||||
path = "../domain"
|
||||
|
||||
[dependencies.lldap_domain_handlers]
|
||||
path = "../domain-handlers"
|
||||
|
||||
[dependencies.lldap_domain_model]
|
||||
path = "../domain-model"
|
||||
|
||||
[dependencies.lldap_opaque_handler]
|
||||
path = "../opaque-handler"
|
||||
|
||||
[dev-dependencies.lldap_test_utils]
|
||||
path = "../test-utils"
|
||||
|
||||
[dev-dependencies]
|
||||
log = "*"
|
||||
mockall = "0.11.4"
|
||||
pretty_assertions = "1"
|
||||
|
||||
[dev-dependencies.tokio]
|
||||
features = ["full"]
|
||||
version = "1.25"
|
||||
|
||||
[dev-dependencies.tracing-subscriber]
|
||||
version = "0.3"
|
||||
features = ["env-filter", "tracing-log"]
|
||||
@@ -0,0 +1,11 @@
|
||||
pub(crate) mod logging;
|
||||
pub(crate) mod sql_backend_handler;
|
||||
pub(crate) mod sql_group_backend_handler;
|
||||
pub(crate) mod sql_opaque_handler;
|
||||
pub(crate) mod sql_schema_backend_handler;
|
||||
pub(crate) mod sql_user_backend_handler;
|
||||
|
||||
pub use sql_backend_handler::SqlBackendHandler;
|
||||
pub use sql_opaque_handler::register_password;
|
||||
pub mod sql_migrations;
|
||||
pub mod sql_tables;
|
||||
@@ -0,0 +1,10 @@
|
||||
#[cfg(test)]
|
||||
pub fn init_for_tests() {
|
||||
if let Err(e) = tracing_subscriber::FmtSubscriber::builder()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
.with_test_writer()
|
||||
.try_init()
|
||||
{
|
||||
log::warn!("Could not set up test logging: {:#}", e);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
use crate::sql_tables::DbConnection;
|
||||
use async_trait::async_trait;
|
||||
use lldap_auth::opaque::server::ServerSetup;
|
||||
use lldap_domain_handlers::handler::BackendHandler;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SqlBackendHandler {
|
||||
pub(crate) opaque_setup: ServerSetup,
|
||||
pub(crate) sql_pool: DbConnection,
|
||||
}
|
||||
|
||||
impl SqlBackendHandler {
|
||||
pub fn new(opaque_setup: ServerSetup, sql_pool: DbConnection) -> Self {
|
||||
SqlBackendHandler {
|
||||
opaque_setup,
|
||||
sql_pool,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pool(&self) -> &DbConnection {
|
||||
&self.sql_pool
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackendHandler for SqlBackendHandler {}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
use crate::sql_tables::init_table;
|
||||
use lldap_auth::{
|
||||
opaque::{self, server::generate_random_private_key},
|
||||
registration,
|
||||
};
|
||||
use lldap_domain::{
|
||||
requests::{CreateGroupRequest, CreateUserRequest},
|
||||
types::{Attribute as DomainAttribute, GroupId, UserId},
|
||||
};
|
||||
use lldap_domain_handlers::handler::{
|
||||
GroupBackendHandler, UserBackendHandler, UserListerBackendHandler, UserRequestFilter,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use sea_orm::Database;
|
||||
|
||||
pub async fn get_in_memory_db() -> DbConnection {
|
||||
crate::logging::init_for_tests();
|
||||
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
|
||||
sql_opt
|
||||
.max_connections(1)
|
||||
.sqlx_logging(true)
|
||||
.sqlx_logging_level(log::LevelFilter::Debug);
|
||||
Database::connect(sql_opt).await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn get_initialized_db() -> DbConnection {
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
sql_pool
|
||||
}
|
||||
|
||||
pub async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) {
|
||||
use lldap_opaque_handler::OpaqueHandler;
|
||||
insert_user_no_password(handler, name).await;
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
let client_registration_start =
|
||||
opaque::client::registration::start_registration(pass.as_bytes(), &mut rng).unwrap();
|
||||
let response = handler
|
||||
.registration_start(registration::ClientRegistrationStartRequest {
|
||||
username: name.into(),
|
||||
registration_start_request: client_registration_start.message,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let registration_upload = opaque::client::registration::finish_registration(
|
||||
client_registration_start.state,
|
||||
response.registration_response,
|
||||
&mut rng,
|
||||
)
|
||||
.unwrap();
|
||||
handler
|
||||
.registration_finish(registration::ClientRegistrationFinishRequest {
|
||||
server_data: response.server_data,
|
||||
registration_upload: registration_upload.message,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
|
||||
handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: UserId::new(name),
|
||||
email: format!("{}@bob.bob", name).into(),
|
||||
display_name: Some("display ".to_string() + name),
|
||||
attributes: vec![
|
||||
DomainAttribute {
|
||||
name: "first_name".into(),
|
||||
value: ("first ".to_string() + name).into(),
|
||||
},
|
||||
DomainAttribute {
|
||||
name: "last_name".into(),
|
||||
value: ("last ".to_string() + name).into(),
|
||||
},
|
||||
],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId {
|
||||
handler
|
||||
.create_group(CreateGroupRequest {
|
||||
display_name: name.into(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
|
||||
handler
|
||||
.add_user_to_group(&UserId::new(user_id), group_id)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn get_user_names(
|
||||
handler: &SqlBackendHandler,
|
||||
filters: Option<UserRequestFilter>,
|
||||
) -> Vec<String> {
|
||||
handler
|
||||
.list_users(filters, false)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user.user_id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub struct TestFixture {
|
||||
pub handler: SqlBackendHandler,
|
||||
pub groups: Vec<GroupId>,
|
||||
}
|
||||
|
||||
impl TestFixture {
|
||||
pub async fn new() -> Self {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool);
|
||||
insert_user_no_password(&handler, "bob").await;
|
||||
insert_user_no_password(&handler, "patrick").await;
|
||||
insert_user_no_password(&handler, "John").await;
|
||||
insert_user_no_password(&handler, "NoGroup").await;
|
||||
let mut groups = vec![];
|
||||
groups.push(insert_group(&handler, "Best Group").await);
|
||||
groups.push(insert_group(&handler, "Worst Group").await);
|
||||
groups.push(insert_group(&handler, "Empty Group").await);
|
||||
insert_membership(&handler, groups[0], "bob").await;
|
||||
insert_membership(&handler, groups[0], "patrick").await;
|
||||
insert_membership(&handler, groups[1], "patrick").await;
|
||||
insert_membership(&handler, groups[1], "John").await;
|
||||
Self { handler, groups }
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sql_injection() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool);
|
||||
let user_name = UserId::new(r#"bob"e"i'o;aü"#);
|
||||
insert_user_no_password(&handler, user_name.as_str()).await;
|
||||
{
|
||||
let users = handler
|
||||
.list_users(None, false)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(users, vec![user_name.clone()]);
|
||||
let user = handler.get_user_details(&user_name).await.unwrap();
|
||||
assert_eq!(user.user_id, user_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,657 @@
|
||||
use crate::sql_backend_handler::SqlBackendHandler;
|
||||
use async_trait::async_trait;
|
||||
use lldap_access_control::UserReadableBackendHandler;
|
||||
use lldap_domain::{
|
||||
requests::{CreateGroupRequest, UpdateGroupRequest},
|
||||
types::{AttributeName, Group, GroupDetails, GroupId, Serialized, Uuid},
|
||||
};
|
||||
use lldap_domain_handlers::handler::{
|
||||
GroupBackendHandler, GroupListerBackendHandler, GroupRequestFilter,
|
||||
};
|
||||
use lldap_domain_model::{
|
||||
error::{DomainError, Result},
|
||||
model::{self, GroupColumn, MembershipColumn, deserialize},
|
||||
};
|
||||
use sea_orm::{
|
||||
ActiveModelTrait, ColumnTrait, DatabaseTransaction, EntityTrait, QueryFilter, QueryOrder,
|
||||
QuerySelect, QueryTrait, Set, TransactionTrait,
|
||||
sea_query::{Alias, Cond, Expr, Func, IntoCondition, OnConflict, SimpleExpr},
|
||||
};
|
||||
use tracing::instrument;
|
||||
|
||||
fn attribute_condition(name: AttributeName, value: Option<Serialized>) -> Cond {
|
||||
Expr::in_subquery(
|
||||
Expr::col(GroupColumn::GroupId.as_column_ref()),
|
||||
model::GroupAttributes::find()
|
||||
.select_only()
|
||||
.column(model::GroupAttributesColumn::GroupId)
|
||||
.filter(model::GroupAttributesColumn::AttributeName.eq(name))
|
||||
.filter(
|
||||
value
|
||||
.map(|value| model::GroupAttributesColumn::Value.eq(value))
|
||||
.unwrap_or_else(|| SimpleExpr::Value(true.into())),
|
||||
)
|
||||
.into_query(),
|
||||
)
|
||||
.into_condition()
|
||||
}
|
||||
|
||||
fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond {
|
||||
use GroupRequestFilter::*;
|
||||
let group_table = Alias::new("groups");
|
||||
match filter {
|
||||
And(fs) => {
|
||||
if fs.is_empty() {
|
||||
SimpleExpr::Value(true.into()).into_condition()
|
||||
} else {
|
||||
fs.into_iter()
|
||||
.fold(Cond::all(), |c, f| c.add(get_group_filter_expr(f)))
|
||||
}
|
||||
}
|
||||
Or(fs) => {
|
||||
if fs.is_empty() {
|
||||
SimpleExpr::Value(false.into()).into_condition()
|
||||
} else {
|
||||
fs.into_iter()
|
||||
.fold(Cond::any(), |c, f| c.add(get_group_filter_expr(f)))
|
||||
}
|
||||
}
|
||||
Not(f) => get_group_filter_expr(*f).not(),
|
||||
DisplayName(name) => GroupColumn::LowercaseDisplayName
|
||||
.eq(name.as_str().to_lowercase())
|
||||
.into_condition(),
|
||||
GroupId(id) => GroupColumn::GroupId.eq(id.0).into_condition(),
|
||||
Uuid(uuid) => GroupColumn::Uuid.eq(uuid.to_string()).into_condition(),
|
||||
// WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user))
|
||||
Member(user) => GroupColumn::GroupId
|
||||
.in_subquery(
|
||||
model::Membership::find()
|
||||
.select_only()
|
||||
.column(MembershipColumn::GroupId)
|
||||
.filter(MembershipColumn::UserId.eq(user))
|
||||
.into_query(),
|
||||
)
|
||||
.into_condition(),
|
||||
DisplayNameSubString(filter) => SimpleExpr::FunctionCall(Func::lower(Expr::col((
|
||||
group_table,
|
||||
GroupColumn::LowercaseDisplayName,
|
||||
))))
|
||||
.like(filter.to_sql_filter())
|
||||
.into_condition(),
|
||||
AttributeEquality(name, value) => attribute_condition(name, Some(value.into())),
|
||||
CustomAttributePresent(name) => attribute_condition(name, None),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GroupListerBackendHandler for SqlBackendHandler {
|
||||
#[instrument(skip(self), level = "debug", ret, err)]
|
||||
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
|
||||
let filters = filters
|
||||
.map(|f| {
|
||||
GroupColumn::GroupId
|
||||
.in_subquery(
|
||||
model::Group::find()
|
||||
.find_also_linked(model::memberships::GroupToUser)
|
||||
.select_only()
|
||||
.column(GroupColumn::GroupId)
|
||||
.filter(get_group_filter_expr(f))
|
||||
.into_query(),
|
||||
)
|
||||
.into_condition()
|
||||
})
|
||||
.unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition());
|
||||
let results = model::Group::find()
|
||||
.order_by_asc(GroupColumn::GroupId)
|
||||
.find_with_related(model::Membership)
|
||||
.filter(filters.clone())
|
||||
.all(&self.sql_pool)
|
||||
.await?;
|
||||
let mut groups: Vec<_> = results
|
||||
.into_iter()
|
||||
.map(|(group, users)| {
|
||||
let users: Vec<_> = users.into_iter().map(|u| u.user_id).collect();
|
||||
Group {
|
||||
users,
|
||||
..group.into()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
// TODO: should be wrapped in a transaction
|
||||
let schema = self.get_schema().await?;
|
||||
let attributes = model::GroupAttributes::find()
|
||||
.filter(
|
||||
model::GroupAttributesColumn::GroupId.in_subquery(
|
||||
model::Group::find()
|
||||
.filter(filters)
|
||||
.select_only()
|
||||
.column(model::groups::Column::GroupId)
|
||||
.into_query(),
|
||||
),
|
||||
)
|
||||
.order_by_asc(model::GroupAttributesColumn::GroupId)
|
||||
.order_by_asc(model::GroupAttributesColumn::AttributeName)
|
||||
.all(&self.sql_pool)
|
||||
.await?;
|
||||
let mut attributes_iter = attributes.into_iter().peekable();
|
||||
use itertools::Itertools; // For take_while_ref
|
||||
for group in groups.iter_mut() {
|
||||
group.attributes = attributes_iter
|
||||
.take_while_ref(|u| u.group_id == group.id)
|
||||
.map(|a| {
|
||||
deserialize::deserialize_attribute(
|
||||
a.attribute_name,
|
||||
&a.value,
|
||||
&schema.get_schema().group_attributes,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
}
|
||||
groups.sort_by(|g1, g2| g1.display_name.cmp(&g2.display_name));
|
||||
Ok(groups)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GroupBackendHandler for SqlBackendHandler {
|
||||
#[instrument(skip(self), level = "debug", ret, err)]
|
||||
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupDetails> {
|
||||
let mut group_details = model::Group::find_by_id(group_id)
|
||||
.one(&self.sql_pool)
|
||||
.await?
|
||||
.map(Into::<GroupDetails>::into)
|
||||
.ok_or_else(|| DomainError::EntityNotFound(format!("{:?}", group_id)))?;
|
||||
let attributes = model::GroupAttributes::find()
|
||||
.filter(model::GroupAttributesColumn::GroupId.eq(group_details.group_id))
|
||||
.order_by_asc(model::GroupAttributesColumn::AttributeName)
|
||||
.all(&self.sql_pool)
|
||||
.await?;
|
||||
let schema = self.get_schema().await?;
|
||||
group_details.attributes = attributes
|
||||
.into_iter()
|
||||
.map(|a| {
|
||||
deserialize::deserialize_attribute(
|
||||
a.attribute_name,
|
||||
&a.value,
|
||||
&schema.get_schema().group_attributes,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(group_details)
|
||||
}
|
||||
|
||||
#[instrument(skip(self), level = "debug", err, fields(group_id = ?request.group_id))]
|
||||
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> {
|
||||
Ok(self
|
||||
.sql_pool
|
||||
.transaction::<_, (), DomainError>(|transaction| {
|
||||
Box::pin(
|
||||
async move { Self::update_group_with_transaction(request, transaction).await },
|
||||
)
|
||||
})
|
||||
.await?)
|
||||
}
|
||||
|
||||
#[instrument(skip(self), level = "debug", ret, err)]
|
||||
async fn create_group(&self, request: CreateGroupRequest) -> Result<GroupId> {
|
||||
let now = chrono::Utc::now().naive_utc();
|
||||
let uuid = Uuid::from_name_and_date(request.display_name.as_str(), &now);
|
||||
let lower_display_name = request.display_name.as_str().to_lowercase();
|
||||
let new_group = model::groups::ActiveModel {
|
||||
display_name: Set(request.display_name),
|
||||
lowercase_display_name: Set(lower_display_name),
|
||||
creation_date: Set(now),
|
||||
uuid: Set(uuid),
|
||||
..Default::default()
|
||||
};
|
||||
Ok(self
|
||||
.sql_pool
|
||||
.transaction::<_, GroupId, DomainError>(|transaction| {
|
||||
Box::pin(async move {
|
||||
let schema = Self::get_schema_with_transaction(transaction).await?;
|
||||
let group_id = new_group.insert(transaction).await?.group_id;
|
||||
let mut new_group_attributes = Vec::new();
|
||||
for attribute in request.attributes {
|
||||
if schema
|
||||
.group_attributes
|
||||
.get_attribute_type(&attribute.name)
|
||||
.is_some()
|
||||
{
|
||||
new_group_attributes.push(model::group_attributes::ActiveModel {
|
||||
group_id: Set(group_id),
|
||||
attribute_name: Set(attribute.name),
|
||||
value: Set(attribute.value.into()),
|
||||
});
|
||||
} else {
|
||||
return Err(DomainError::InternalError(format!(
|
||||
"Attribute name {} doesn't exist in the group schema,
|
||||
yet was attempted to be inserted in the database",
|
||||
&attribute.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
if !new_group_attributes.is_empty() {
|
||||
model::GroupAttributes::insert_many(new_group_attributes)
|
||||
.exec(transaction)
|
||||
.await?;
|
||||
}
|
||||
Ok(group_id)
|
||||
})
|
||||
})
|
||||
.await?)
|
||||
}
|
||||
|
||||
#[instrument(skip(self), level = "debug", err)]
|
||||
async fn delete_group(&self, group_id: GroupId) -> Result<()> {
|
||||
let res = model::Group::delete_by_id(group_id)
|
||||
.exec(&self.sql_pool)
|
||||
.await?;
|
||||
if res.rows_affected == 0 {
|
||||
return Err(DomainError::EntityNotFound(format!(
|
||||
"No such group: '{:?}'",
|
||||
group_id
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SqlBackendHandler {
|
||||
async fn update_group_with_transaction(
|
||||
request: UpdateGroupRequest,
|
||||
transaction: &DatabaseTransaction,
|
||||
) -> Result<()> {
|
||||
let lower_display_name = request
|
||||
.display_name
|
||||
.as_ref()
|
||||
.map(|s| s.as_str().to_lowercase());
|
||||
let update_group = model::groups::ActiveModel {
|
||||
group_id: Set(request.group_id),
|
||||
display_name: request.display_name.map(Set).unwrap_or_default(),
|
||||
lowercase_display_name: lower_display_name.map(Set).unwrap_or_default(),
|
||||
..Default::default()
|
||||
};
|
||||
update_group.update(transaction).await?;
|
||||
let mut update_group_attributes = Vec::new();
|
||||
let mut remove_group_attributes = Vec::new();
|
||||
let schema = Self::get_schema_with_transaction(transaction).await?;
|
||||
for attribute in request.insert_attributes {
|
||||
if schema
|
||||
.group_attributes
|
||||
.get_attribute_type(&attribute.name)
|
||||
.is_some()
|
||||
{
|
||||
update_group_attributes.push(model::group_attributes::ActiveModel {
|
||||
group_id: Set(request.group_id),
|
||||
attribute_name: Set(attribute.name.to_owned()),
|
||||
value: Set(attribute.value.into()),
|
||||
});
|
||||
} else {
|
||||
return Err(DomainError::InternalError(format!(
|
||||
"Group attribute name {} doesn't exist in the schema, yet was attempted to be inserted in the database",
|
||||
&attribute.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
for attribute in request.delete_attributes {
|
||||
if schema
|
||||
.group_attributes
|
||||
.get_attribute_type(&attribute)
|
||||
.is_some()
|
||||
{
|
||||
remove_group_attributes.push(attribute);
|
||||
} else {
|
||||
return Err(DomainError::InternalError(format!(
|
||||
"Group attribute name {} doesn't exist in the schema, yet was attempted to be removed from the database",
|
||||
attribute
|
||||
)));
|
||||
}
|
||||
}
|
||||
if !remove_group_attributes.is_empty() {
|
||||
model::GroupAttributes::delete_many()
|
||||
.filter(model::GroupAttributesColumn::GroupId.eq(request.group_id))
|
||||
.filter(model::GroupAttributesColumn::AttributeName.is_in(remove_group_attributes))
|
||||
.exec(transaction)
|
||||
.await?;
|
||||
}
|
||||
if !update_group_attributes.is_empty() {
|
||||
model::GroupAttributes::insert_many(update_group_attributes)
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
model::GroupAttributesColumn::GroupId,
|
||||
model::GroupAttributesColumn::AttributeName,
|
||||
])
|
||||
.update_column(model::GroupAttributesColumn::Value)
|
||||
.to_owned(),
|
||||
)
|
||||
.exec(transaction)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::sql_backend_handler::tests::*;
|
||||
use lldap_domain::{
|
||||
requests::CreateAttributeRequest,
|
||||
types::{Attribute, AttributeType, GroupName, UserId},
|
||||
};
|
||||
use lldap_domain_handlers::handler::{SchemaBackendHandler, SubStringFilter};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
async fn get_group_ids(
|
||||
handler: &SqlBackendHandler,
|
||||
filters: Option<GroupRequestFilter>,
|
||||
) -> Vec<GroupId> {
|
||||
handler
|
||||
.list_groups(filters)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|g| g.id)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
async fn get_group_names(
|
||||
handler: &SqlBackendHandler,
|
||||
filters: Option<GroupRequestFilter>,
|
||||
) -> Vec<GroupName> {
|
||||
handler
|
||||
.list_groups(filters)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|g| g.display_name)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_groups_no_filter() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
get_group_names(&fixture.handler, None).await,
|
||||
vec![
|
||||
"Best Group".into(),
|
||||
"Empty Group".into(),
|
||||
"Worst Group".into()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_groups_simple_filter() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
get_group_names(
|
||||
&fixture.handler,
|
||||
Some(GroupRequestFilter::Or(vec![
|
||||
GroupRequestFilter::DisplayName("Empty Group".into()),
|
||||
GroupRequestFilter::Member(UserId::new("bob")),
|
||||
]))
|
||||
)
|
||||
.await,
|
||||
vec!["Best Group".into(), "Empty Group".into()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_groups_case_insensitive_filter() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
get_group_names(
|
||||
&fixture.handler,
|
||||
Some(GroupRequestFilter::DisplayName("eMpTy gRoup".into()),)
|
||||
)
|
||||
.await,
|
||||
vec!["Empty Group".into()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_groups_negation() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
get_group_ids(
|
||||
&fixture.handler,
|
||||
Some(GroupRequestFilter::And(vec![
|
||||
GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName(
|
||||
"value".into()
|
||||
))),
|
||||
GroupRequestFilter::GroupId(fixture.groups[0]),
|
||||
]))
|
||||
)
|
||||
.await,
|
||||
vec![fixture.groups[0]]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_groups_substring_filter() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
get_group_ids(
|
||||
&fixture.handler,
|
||||
Some(GroupRequestFilter::DisplayNameSubString(SubStringFilter {
|
||||
initial: Some("be".to_owned()),
|
||||
any: vec!["sT".to_owned()],
|
||||
final_: Some("P".to_owned()),
|
||||
})),
|
||||
)
|
||||
.await,
|
||||
// Best group
|
||||
vec![fixture.groups[0]]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_groups_other_filter() {
|
||||
let fixture = TestFixture::new().await;
|
||||
fixture
|
||||
.handler
|
||||
.add_group_attribute(CreateAttributeRequest {
|
||||
name: "gid".into(),
|
||||
attribute_type: AttributeType::Integer,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: true,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
fixture
|
||||
.handler
|
||||
.update_group(UpdateGroupRequest {
|
||||
group_id: fixture.groups[0],
|
||||
display_name: None,
|
||||
delete_attributes: Vec::new(),
|
||||
insert_attributes: vec![Attribute {
|
||||
name: "gid".into(),
|
||||
value: 512.into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
get_group_ids(
|
||||
&fixture.handler,
|
||||
Some(GroupRequestFilter::AttributeEquality(
|
||||
AttributeName::from("gid"),
|
||||
512.into(),
|
||||
)),
|
||||
)
|
||||
.await,
|
||||
vec![fixture.groups[0]]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_group_details() {
|
||||
let fixture = TestFixture::new().await;
|
||||
let details = fixture
|
||||
.handler
|
||||
.get_group_details(fixture.groups[0])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(details.group_id, fixture.groups[0]);
|
||||
assert_eq!(details.display_name, "Best Group".into());
|
||||
assert_eq!(
|
||||
get_group_ids(
|
||||
&fixture.handler,
|
||||
Some(GroupRequestFilter::Uuid(details.uuid))
|
||||
)
|
||||
.await,
|
||||
vec![fixture.groups[0]]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_group() {
|
||||
let fixture = TestFixture::new().await;
|
||||
fixture
|
||||
.handler
|
||||
.update_group(UpdateGroupRequest {
|
||||
group_id: fixture.groups[0],
|
||||
display_name: Some("Awesomest Group".into()),
|
||||
delete_attributes: Vec::new(),
|
||||
insert_attributes: Vec::new(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let details = fixture
|
||||
.handler
|
||||
.get_group_details(fixture.groups[0])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(details.display_name, "Awesomest Group".into());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_group() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
get_group_ids(&fixture.handler, None).await,
|
||||
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]]
|
||||
);
|
||||
fixture
|
||||
.handler
|
||||
.delete_group(fixture.groups[0])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
get_group_ids(&fixture.handler, None).await,
|
||||
vec![fixture.groups[2], fixture.groups[1]]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_group() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
get_group_ids(&fixture.handler, None).await,
|
||||
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]]
|
||||
);
|
||||
fixture
|
||||
.handler
|
||||
.add_group_attribute(CreateAttributeRequest {
|
||||
name: "new_attribute".into(),
|
||||
attribute_type: AttributeType::String,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: true,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let new_group_id = fixture
|
||||
.handler
|
||||
.create_group(CreateGroupRequest {
|
||||
display_name: "New Group".into(),
|
||||
attributes: vec![Attribute {
|
||||
name: "new_attribute".into(),
|
||||
value: "value".to_string().into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let group_details = fixture
|
||||
.handler
|
||||
.get_group_details(new_group_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(group_details.display_name, "New Group".into());
|
||||
assert_eq!(
|
||||
group_details.attributes,
|
||||
vec![Attribute {
|
||||
name: "new_attribute".into(),
|
||||
value: "value".to_string().into(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_group_attributes() {
|
||||
let fixture = TestFixture::new().await;
|
||||
fixture
|
||||
.handler
|
||||
.add_group_attribute(CreateAttributeRequest {
|
||||
name: "new_attribute".into(),
|
||||
attribute_type: AttributeType::Integer,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: true,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let group_id = fixture.groups[0];
|
||||
let attributes = vec![Attribute {
|
||||
name: "new_attribute".into(),
|
||||
value: 42i64.into(),
|
||||
}];
|
||||
fixture
|
||||
.handler
|
||||
.update_group(UpdateGroupRequest {
|
||||
group_id,
|
||||
display_name: None,
|
||||
delete_attributes: Vec::new(),
|
||||
insert_attributes: attributes.clone(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let details = fixture.handler.get_group_details(group_id).await.unwrap();
|
||||
assert_eq!(details.attributes, attributes);
|
||||
fixture
|
||||
.handler
|
||||
.update_group(UpdateGroupRequest {
|
||||
group_id,
|
||||
display_name: None,
|
||||
delete_attributes: vec!["new_attribute".into()],
|
||||
insert_attributes: Vec::new(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let details = fixture.handler.get_group_details(group_id).await.unwrap();
|
||||
assert_eq!(details.attributes, Vec::new());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_group_duplicate_name() {
|
||||
let fixture = TestFixture::new().await;
|
||||
fixture
|
||||
.handler
|
||||
.create_group(CreateGroupRequest {
|
||||
display_name: "New Group".into(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
fixture
|
||||
.handler
|
||||
.create_group(CreateGroupRequest {
|
||||
display_name: "neW group".into(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,343 @@
|
||||
use crate::SqlBackendHandler;
|
||||
use async_trait::async_trait;
|
||||
use base64::Engine;
|
||||
use lldap_auth::opaque;
|
||||
use lldap_domain::types::UserId;
|
||||
use lldap_domain_handlers::handler::{BindRequest, LoginHandler};
|
||||
use lldap_domain_model::{
|
||||
error::{DomainError, Result},
|
||||
model::{self, UserColumn},
|
||||
};
|
||||
use lldap_opaque_handler::{OpaqueHandler, login, registration};
|
||||
use sea_orm::{ActiveModelTrait, ActiveValue, EntityTrait, QuerySelect};
|
||||
use secstr::SecUtf8;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
type SqlOpaqueHandler = SqlBackendHandler;
|
||||
|
||||
#[instrument(skip_all, level = "debug", err, fields(username = %username.as_str()))]
|
||||
fn passwords_match(
|
||||
password_file_bytes: &[u8],
|
||||
clear_password: &str,
|
||||
opaque_setup: &opaque::server::ServerSetup,
|
||||
username: &UserId,
|
||||
) -> Result<()> {
|
||||
use opaque::{client, server};
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
let client_login_start_result = client::login::start_login(clear_password, &mut rng)?;
|
||||
|
||||
let password_file = server::ServerRegistration::deserialize(password_file_bytes)
|
||||
.map_err(opaque::AuthenticationError::ProtocolError)?;
|
||||
let server_login_start_result = server::login::start_login(
|
||||
&mut rng,
|
||||
opaque_setup,
|
||||
Some(password_file),
|
||||
client_login_start_result.message,
|
||||
username,
|
||||
)?;
|
||||
client::login::finish_login(
|
||||
client_login_start_result.state,
|
||||
server_login_start_result.message,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl SqlBackendHandler {
|
||||
fn get_orion_secret_key(&self) -> Result<orion::aead::SecretKey> {
|
||||
Ok(orion::aead::SecretKey::from_slice(
|
||||
self.opaque_setup.keypair().private(),
|
||||
)?)
|
||||
}
|
||||
|
||||
#[instrument(skip(self), level = "debug", err)]
|
||||
async fn get_password_file_for_user(&self, user_id: UserId) -> Result<Option<Vec<u8>>> {
|
||||
// Fetch the previously registered password file from the DB.
|
||||
Ok(model::User::find_by_id(user_id)
|
||||
.select_only()
|
||||
.column(UserColumn::PasswordHash)
|
||||
.into_tuple::<(Option<Vec<u8>>,)>()
|
||||
.one(&self.sql_pool)
|
||||
.await?
|
||||
.and_then(|u| u.0))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LoginHandler for SqlBackendHandler {
|
||||
#[instrument(skip_all, level = "debug", err)]
|
||||
async fn bind(&self, request: BindRequest) -> Result<()> {
|
||||
if let Some(password_hash) = self
|
||||
.get_password_file_for_user(request.name.clone())
|
||||
.await?
|
||||
{
|
||||
info!(r#"Login attempt for "{}""#, &request.name);
|
||||
if passwords_match(
|
||||
&password_hash,
|
||||
&request.password,
|
||||
&self.opaque_setup,
|
||||
&request.name,
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
r#"User "{}" doesn't exist or has no password"#,
|
||||
&request.name
|
||||
);
|
||||
}
|
||||
Err(DomainError::AuthenticationError(format!(
|
||||
r#"for user "{}""#,
|
||||
request.name
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl OpaqueHandler for SqlOpaqueHandler {
|
||||
#[instrument(skip_all, level = "debug", err)]
|
||||
async fn login_start(
|
||||
&self,
|
||||
request: login::ClientLoginStartRequest,
|
||||
) -> Result<login::ServerLoginStartResponse> {
|
||||
let user_id = request.username;
|
||||
info!(r#"OPAQUE login attempt for "{}""#, &user_id);
|
||||
let maybe_password_file = self
|
||||
.get_password_file_for_user(user_id.clone())
|
||||
.await?
|
||||
.map(|bytes| {
|
||||
opaque::server::ServerRegistration::deserialize(&bytes).map_err(|_| {
|
||||
DomainError::InternalError(format!("Corrupted password file for {}", &user_id))
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
// Get the CredentialResponse for the user, or a dummy one if no user/no password.
|
||||
let start_response = opaque::server::login::start_login(
|
||||
&mut rng,
|
||||
&self.opaque_setup,
|
||||
maybe_password_file,
|
||||
request.login_start_request,
|
||||
&user_id,
|
||||
)?;
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let server_data = login::ServerData {
|
||||
username: user_id,
|
||||
server_login: start_response.state,
|
||||
};
|
||||
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
|
||||
|
||||
Ok(login::ServerLoginStartResponse {
|
||||
server_data: base64::engine::general_purpose::STANDARD.encode(encrypted_state),
|
||||
credential_response: start_response.message,
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all, level = "debug", err)]
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId> {
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let login::ServerData {
|
||||
username,
|
||||
server_login,
|
||||
} = bincode::deserialize(&orion::aead::open(
|
||||
&secret_key,
|
||||
&base64::engine::general_purpose::STANDARD.decode(&request.server_data)?,
|
||||
)?)?;
|
||||
// Finish the login: this makes sure the client data is correct, and gives a session key we
|
||||
// don't need.
|
||||
match opaque::server::login::finish_login(server_login, request.credential_finalization) {
|
||||
Ok(session) => {
|
||||
info!(r#"OPAQUE login successful for "{}""#, &username);
|
||||
let _ = session.session_key;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(r#"OPAQUE login attempt failed for "{}""#, &username);
|
||||
return Err(e.into());
|
||||
}
|
||||
};
|
||||
|
||||
Ok(username)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, level = "debug", err)]
|
||||
async fn registration_start(
|
||||
&self,
|
||||
request: registration::ClientRegistrationStartRequest,
|
||||
) -> Result<registration::ServerRegistrationStartResponse> {
|
||||
// Generate the server-side key and derive the data to send back.
|
||||
let start_response = opaque::server::registration::start_registration(
|
||||
&self.opaque_setup,
|
||||
request.registration_start_request,
|
||||
&request.username,
|
||||
)?;
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let server_data = registration::ServerData {
|
||||
username: request.username,
|
||||
};
|
||||
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
|
||||
Ok(registration::ServerRegistrationStartResponse {
|
||||
server_data: base64::engine::general_purpose::STANDARD.encode(encrypted_state),
|
||||
registration_response: start_response.message,
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all, level = "debug", err)]
|
||||
async fn registration_finish(
|
||||
&self,
|
||||
request: registration::ClientRegistrationFinishRequest,
|
||||
) -> Result<()> {
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let registration::ServerData { username } = bincode::deserialize(&orion::aead::open(
|
||||
&secret_key,
|
||||
&base64::engine::general_purpose::STANDARD.decode(&request.server_data)?,
|
||||
)?)?;
|
||||
|
||||
let password_file =
|
||||
opaque::server::registration::get_password_file(request.registration_upload);
|
||||
// Set the user password to the new password.
|
||||
let user_update = model::users::ActiveModel {
|
||||
user_id: ActiveValue::Set(username.clone()),
|
||||
password_hash: ActiveValue::Set(Some(password_file.serialize())),
|
||||
..Default::default()
|
||||
};
|
||||
user_update.update(&self.sql_pool).await?;
|
||||
info!(r#"Successfully (re)set password for "{}""#, &username);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience function to set a user's password.
|
||||
#[instrument(skip_all, level = "debug", err, fields(username = %username.as_str()))]
|
||||
pub async fn register_password(
|
||||
opaque_handler: &SqlOpaqueHandler,
|
||||
username: UserId,
|
||||
password: &SecUtf8,
|
||||
) -> Result<()> {
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
use registration::*;
|
||||
let registration_start =
|
||||
opaque::client::registration::start_registration(password.unsecure().as_bytes(), &mut rng)?;
|
||||
let start_response = opaque_handler
|
||||
.registration_start(ClientRegistrationStartRequest {
|
||||
username,
|
||||
registration_start_request: registration_start.message,
|
||||
})
|
||||
.await?;
|
||||
let registration_finish = opaque::client::registration::finish_registration(
|
||||
registration_start.state,
|
||||
start_response.registration_response,
|
||||
&mut rng,
|
||||
)?;
|
||||
opaque_handler
|
||||
.registration_finish(ClientRegistrationFinishRequest {
|
||||
server_data: start_response.server_data,
|
||||
registration_upload: registration_finish.message,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use self::opaque::server::generate_random_private_key;
|
||||
|
||||
use super::*;
|
||||
use crate::sql_backend_handler::tests::{
|
||||
get_initialized_db, insert_user, insert_user_no_password,
|
||||
};
|
||||
|
||||
async fn attempt_login(
|
||||
opaque_handler: &SqlOpaqueHandler,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<()> {
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
use login::*;
|
||||
let login_start = opaque::client::login::start_login(password, &mut rng)?;
|
||||
let start_response = opaque_handler
|
||||
.login_start(ClientLoginStartRequest {
|
||||
username: UserId::new(username),
|
||||
login_start_request: login_start.message,
|
||||
})
|
||||
.await?;
|
||||
let login_finish = opaque::client::login::finish_login(
|
||||
login_start.state,
|
||||
start_response.credential_response,
|
||||
)?;
|
||||
opaque_handler
|
||||
.login_finish(ClientLoginFinishRequest {
|
||||
server_data: start_response.server_data,
|
||||
credential_finalization: login_finish.message,
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_opaque_flow() -> Result<()> {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
crate::logging::init_for_tests();
|
||||
let backend_handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool);
|
||||
insert_user_no_password(&backend_handler, "bob").await;
|
||||
insert_user_no_password(&backend_handler, "john").await;
|
||||
attempt_login(&backend_handler, "bob", "bob00")
|
||||
.await
|
||||
.unwrap_err();
|
||||
register_password(
|
||||
&backend_handler,
|
||||
UserId::new("bob"),
|
||||
&secstr::SecUtf8::from("bob00"),
|
||||
)
|
||||
.await?;
|
||||
attempt_login(&backend_handler, "bob", "wrong_password")
|
||||
.await
|
||||
.unwrap_err();
|
||||
attempt_login(&backend_handler, "bob", "bob00").await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bind_user() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let handler = SqlOpaqueHandler::new(generate_random_private_key(), sql_pool.clone());
|
||||
insert_user(&handler, "bob", "bob00").await;
|
||||
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: UserId::new("bob"),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: UserId::new("andrew"),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: UserId::new("bob"),
|
||||
password: "wrong_password".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_no_password() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool.clone());
|
||||
insert_user_no_password(&handler, "bob").await;
|
||||
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: UserId::new("bob"),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,417 @@
|
||||
use crate::sql_backend_handler::SqlBackendHandler;
|
||||
use async_trait::async_trait;
|
||||
use lldap_domain::{
|
||||
requests::CreateAttributeRequest,
|
||||
schema::{AttributeList, AttributeSchema, Schema},
|
||||
types::{AttributeName, LdapObjectClass},
|
||||
};
|
||||
use lldap_domain_handlers::handler::{ReadSchemaBackendHandler, SchemaBackendHandler};
|
||||
use lldap_domain_model::{
|
||||
error::{DomainError, Result},
|
||||
model,
|
||||
};
|
||||
use sea_orm::{
|
||||
ActiveModelTrait, DatabaseTransaction, EntityTrait, QueryOrder, Set, TransactionTrait,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
impl ReadSchemaBackendHandler for SqlBackendHandler {
|
||||
async fn get_schema(&self) -> Result<Schema> {
|
||||
Ok(self
|
||||
.sql_pool
|
||||
.transaction::<_, Schema, DomainError>(|transaction| {
|
||||
Box::pin(async move { Self::get_schema_with_transaction(transaction).await })
|
||||
})
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SchemaBackendHandler for SqlBackendHandler {
|
||||
async fn add_user_attribute(&self, request: CreateAttributeRequest) -> Result<()> {
|
||||
let new_attribute = model::user_attribute_schema::ActiveModel {
|
||||
attribute_name: Set(request.name),
|
||||
attribute_type: Set(request.attribute_type),
|
||||
is_list: Set(request.is_list),
|
||||
is_user_visible: Set(request.is_visible),
|
||||
is_user_editable: Set(request.is_editable),
|
||||
is_hardcoded: Set(false),
|
||||
};
|
||||
new_attribute.insert(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_group_attribute(&self, request: CreateAttributeRequest) -> Result<()> {
|
||||
let new_attribute = model::group_attribute_schema::ActiveModel {
|
||||
attribute_name: Set(request.name),
|
||||
attribute_type: Set(request.attribute_type),
|
||||
is_list: Set(request.is_list),
|
||||
is_group_visible: Set(request.is_visible),
|
||||
is_group_editable: Set(request.is_editable),
|
||||
is_hardcoded: Set(false),
|
||||
};
|
||||
new_attribute.insert(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_user_attribute(&self, name: &AttributeName) -> Result<()> {
|
||||
model::UserAttributeSchema::delete_by_id(name.clone())
|
||||
.exec(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_group_attribute(&self, name: &AttributeName) -> Result<()> {
|
||||
model::GroupAttributeSchema::delete_by_id(name.clone())
|
||||
.exec(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_user_object_class(&self, name: &LdapObjectClass) -> Result<()> {
|
||||
let mut name_key = name.to_string();
|
||||
name_key.make_ascii_lowercase();
|
||||
model::user_object_classes::ActiveModel {
|
||||
lower_object_class: Set(name_key),
|
||||
object_class: Set(name.clone()),
|
||||
}
|
||||
.insert(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_group_object_class(&self, name: &LdapObjectClass) -> Result<()> {
|
||||
let mut name_key = name.to_string();
|
||||
name_key.make_ascii_lowercase();
|
||||
model::group_object_classes::ActiveModel {
|
||||
lower_object_class: Set(name_key),
|
||||
object_class: Set(name.clone()),
|
||||
}
|
||||
.insert(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_user_object_class(&self, name: &LdapObjectClass) -> Result<()> {
|
||||
model::UserObjectClasses::delete_by_id(name.as_str().to_ascii_lowercase())
|
||||
.exec(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_group_object_class(&self, name: &LdapObjectClass) -> Result<()> {
|
||||
model::GroupObjectClasses::delete_by_id(name.as_str().to_ascii_lowercase())
|
||||
.exec(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SqlBackendHandler {
|
||||
pub(crate) async fn get_schema_with_transaction(
|
||||
transaction: &DatabaseTransaction,
|
||||
) -> Result<Schema> {
|
||||
Ok(Schema {
|
||||
user_attributes: AttributeList {
|
||||
attributes: Self::get_user_attributes(transaction).await?,
|
||||
},
|
||||
group_attributes: AttributeList {
|
||||
attributes: Self::get_group_attributes(transaction).await?,
|
||||
},
|
||||
extra_user_object_classes: Self::get_user_object_classes(transaction).await?,
|
||||
extra_group_object_classes: Self::get_group_object_classes(transaction).await?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_user_attributes(
|
||||
transaction: &DatabaseTransaction,
|
||||
) -> Result<Vec<AttributeSchema>> {
|
||||
Ok(model::UserAttributeSchema::find()
|
||||
.order_by_asc(model::UserAttributeSchemaColumn::AttributeName)
|
||||
.all(transaction)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|m| m.into())
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn get_group_attributes(
|
||||
transaction: &DatabaseTransaction,
|
||||
) -> Result<Vec<AttributeSchema>> {
|
||||
Ok(model::GroupAttributeSchema::find()
|
||||
.order_by_asc(model::GroupAttributeSchemaColumn::AttributeName)
|
||||
.all(transaction)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|m| m.into())
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn get_user_object_classes(
|
||||
transaction: &DatabaseTransaction,
|
||||
) -> Result<Vec<LdapObjectClass>> {
|
||||
Ok(model::UserObjectClasses::find()
|
||||
.order_by_asc(model::UserObjectClassesColumn::ObjectClass)
|
||||
.all(transaction)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn get_group_object_classes(
|
||||
transaction: &DatabaseTransaction,
|
||||
) -> Result<Vec<LdapObjectClass>> {
|
||||
Ok(model::GroupObjectClasses::find()
|
||||
.order_by_asc(model::GroupObjectClassesColumn::ObjectClass)
|
||||
.all(transaction)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::sql_backend_handler::tests::*;
|
||||
use lldap_domain::requests::UpdateUserRequest;
|
||||
use lldap_domain::schema::AttributeList;
|
||||
use lldap_domain::types::{Attribute, AttributeType};
|
||||
use lldap_domain_handlers::handler::{UserBackendHandler, UserRequestFilter};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_schema() {
|
||||
let fixture = TestFixture::new().await;
|
||||
assert_eq!(
|
||||
fixture.handler.get_schema().await.unwrap(),
|
||||
Schema {
|
||||
user_attributes: AttributeList {
|
||||
attributes: vec![
|
||||
AttributeSchema {
|
||||
name: "avatar".into(),
|
||||
attribute_type: AttributeType::JpegPhoto,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: true,
|
||||
is_hardcoded: true,
|
||||
is_readonly: false,
|
||||
},
|
||||
AttributeSchema {
|
||||
name: "first_name".into(),
|
||||
attribute_type: AttributeType::String,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: true,
|
||||
is_hardcoded: true,
|
||||
is_readonly: false,
|
||||
},
|
||||
AttributeSchema {
|
||||
name: "last_name".into(),
|
||||
attribute_type: AttributeType::String,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: true,
|
||||
is_hardcoded: true,
|
||||
is_readonly: false,
|
||||
}
|
||||
]
|
||||
},
|
||||
group_attributes: AttributeList {
|
||||
attributes: Vec::new()
|
||||
},
|
||||
extra_user_object_classes: Vec::new(),
|
||||
extra_group_object_classes: Vec::new(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_attribute_add_and_delete() {
|
||||
let fixture = TestFixture::new().await;
|
||||
let new_attribute = CreateAttributeRequest {
|
||||
name: "new_attribute".into(),
|
||||
attribute_type: AttributeType::Integer,
|
||||
is_list: true,
|
||||
is_visible: false,
|
||||
is_editable: false,
|
||||
};
|
||||
fixture
|
||||
.handler
|
||||
.add_user_attribute(new_attribute)
|
||||
.await
|
||||
.unwrap();
|
||||
let expected_value = AttributeSchema {
|
||||
name: "new_attribute".into(),
|
||||
attribute_type: AttributeType::Integer,
|
||||
is_list: true,
|
||||
is_visible: false,
|
||||
is_editable: false,
|
||||
is_hardcoded: false,
|
||||
is_readonly: false,
|
||||
};
|
||||
assert!(
|
||||
fixture
|
||||
.handler
|
||||
.get_schema()
|
||||
.await
|
||||
.unwrap()
|
||||
.user_attributes
|
||||
.attributes
|
||||
.contains(&expected_value)
|
||||
);
|
||||
fixture
|
||||
.handler
|
||||
.delete_user_attribute(&"new_attribute".into())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
!fixture
|
||||
.handler
|
||||
.get_schema()
|
||||
.await
|
||||
.unwrap()
|
||||
.user_attributes
|
||||
.attributes
|
||||
.contains(&expected_value)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_attribute_present_filter() {
|
||||
let fixture = TestFixture::new().await;
|
||||
let new_attribute = CreateAttributeRequest {
|
||||
name: "new_attribute".into(),
|
||||
attribute_type: AttributeType::Integer,
|
||||
is_list: true,
|
||||
is_visible: false,
|
||||
is_editable: false,
|
||||
};
|
||||
fixture
|
||||
.handler
|
||||
.add_user_attribute(new_attribute)
|
||||
.await
|
||||
.unwrap();
|
||||
fixture
|
||||
.handler
|
||||
.update_user(UpdateUserRequest {
|
||||
user_id: "bob".into(),
|
||||
insert_attributes: vec![Attribute {
|
||||
name: "new_attribute".into(),
|
||||
value: vec![3].into(),
|
||||
}],
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let users = get_user_names(
|
||||
&fixture.handler,
|
||||
Some(UserRequestFilter::CustomAttributePresent(
|
||||
"new_attribute".into(),
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(users, vec!["bob"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_group_attribute_add_and_delete() {
|
||||
let fixture = TestFixture::new().await;
|
||||
let new_attribute = CreateAttributeRequest {
|
||||
name: "NeW_aTTribute".into(),
|
||||
attribute_type: AttributeType::JpegPhoto,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: false,
|
||||
};
|
||||
fixture
|
||||
.handler
|
||||
.add_group_attribute(new_attribute)
|
||||
.await
|
||||
.unwrap();
|
||||
let expected_value = AttributeSchema {
|
||||
name: "new_attribute".into(),
|
||||
attribute_type: AttributeType::JpegPhoto,
|
||||
is_list: false,
|
||||
is_visible: true,
|
||||
is_editable: false,
|
||||
is_hardcoded: false,
|
||||
is_readonly: false,
|
||||
};
|
||||
assert!(
|
||||
fixture
|
||||
.handler
|
||||
.get_schema()
|
||||
.await
|
||||
.unwrap()
|
||||
.group_attributes
|
||||
.attributes
|
||||
.contains(&expected_value)
|
||||
);
|
||||
fixture
|
||||
.handler
|
||||
.delete_group_attribute(&"new_attriBUte".into())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
!fixture
|
||||
.handler
|
||||
.get_schema()
|
||||
.await
|
||||
.unwrap()
|
||||
.group_attributes
|
||||
.attributes
|
||||
.contains(&expected_value)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_object_class_add_and_delete() {
|
||||
let fixture = TestFixture::new().await;
|
||||
let new_object_class = LdapObjectClass::new("newObjectClass");
|
||||
fixture
|
||||
.handler
|
||||
.add_user_object_class(&new_object_class)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
fixture
|
||||
.handler
|
||||
.get_schema()
|
||||
.await
|
||||
.unwrap()
|
||||
.extra_user_object_classes,
|
||||
vec![new_object_class.clone()]
|
||||
);
|
||||
fixture
|
||||
.handler
|
||||
.add_user_object_class(&LdapObjectClass::new("newobjEctclass"))
|
||||
.await
|
||||
.expect_err("Should not be able to add the same object class twice");
|
||||
assert_eq!(
|
||||
fixture
|
||||
.handler
|
||||
.get_schema()
|
||||
.await
|
||||
.unwrap()
|
||||
.extra_user_object_classes,
|
||||
vec![new_object_class.clone()]
|
||||
);
|
||||
fixture
|
||||
.handler
|
||||
.delete_user_object_class(&new_object_class)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
fixture
|
||||
.handler
|
||||
.get_schema()
|
||||
.await
|
||||
.unwrap()
|
||||
.extra_user_object_classes
|
||||
.is_empty()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,566 @@
|
||||
use crate::sql_migrations::{Metadata, get_schema_version, migrate_from_version, upgrade_to_v1};
|
||||
use sea_orm::{
|
||||
ConnectionTrait, DeriveValueType, Iden, QueryResult, TryGetable, Value, sea_query::Query,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub type DbConnection = sea_orm::DatabaseConnection;
|
||||
|
||||
#[derive(Copy, PartialEq, Eq, Debug, Clone, PartialOrd, Ord, DeriveValueType)]
|
||||
pub struct SchemaVersion(pub i16);
|
||||
|
||||
pub const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(10);
|
||||
|
||||
#[derive(Copy, PartialEq, Eq, Debug, Clone, PartialOrd, Ord)]
|
||||
pub struct PrivateKeyHash(pub [u8; 32]);
|
||||
|
||||
impl TryGetable for PrivateKeyHash {
|
||||
fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, sea_orm::TryGetError> {
|
||||
let index = format!("{pre}{col}");
|
||||
Self::try_get_by(res, index.as_str())
|
||||
}
|
||||
|
||||
fn try_get_by_index(res: &QueryResult, index: usize) -> Result<Self, sea_orm::TryGetError> {
|
||||
Self::try_get_by(res, index)
|
||||
}
|
||||
|
||||
fn try_get_by<I: sea_orm::ColIdx>(
|
||||
res: &QueryResult,
|
||||
index: I,
|
||||
) -> Result<Self, sea_orm::TryGetError> {
|
||||
Ok(PrivateKeyHash(
|
||||
std::convert::TryInto::<[u8; 32]>::try_into(res.try_get_by::<Vec<u8>, I>(index)?)
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PrivateKeyHash> for Value {
|
||||
fn from(val: PrivateKeyHash) -> Self {
|
||||
Self::from(val.0.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> {
|
||||
let version = {
|
||||
if let Some(version) = get_schema_version(pool).await {
|
||||
version
|
||||
} else {
|
||||
upgrade_to_v1(pool).await?;
|
||||
SchemaVersion(1)
|
||||
}
|
||||
};
|
||||
migrate_from_version(pool, version, LAST_SCHEMA_VERSION).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ConfigLocation {
|
||||
ConfigFile(String),
|
||||
EnvironmentVariable(String),
|
||||
CommandLine,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PrivateKeyLocation {
|
||||
KeySeed(ConfigLocation),
|
||||
KeyFile(ConfigLocation, std::ffi::OsString),
|
||||
Default,
|
||||
Tests,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PrivateKeyInfo {
|
||||
pub private_key_hash: PrivateKeyHash,
|
||||
pub private_key_location: PrivateKeyLocation,
|
||||
}
|
||||
|
||||
pub async fn get_private_key_info(pool: &DbConnection) -> anyhow::Result<Option<PrivateKeyInfo>> {
|
||||
let result = pool
|
||||
.query_one(
|
||||
pool.get_database_backend().build(
|
||||
Query::select()
|
||||
.column(Metadata::PrivateKeyHash)
|
||||
.column(Metadata::PrivateKeyLocation)
|
||||
.from(Metadata::Table),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
let result = match result {
|
||||
None => return Ok(None),
|
||||
Some(r) => r,
|
||||
};
|
||||
if let Ok(hash) = result.try_get("", &Metadata::PrivateKeyHash.to_string()) {
|
||||
Ok(Some(PrivateKeyInfo {
|
||||
private_key_hash: hash,
|
||||
private_key_location: serde_json::from_str(
|
||||
&result.try_get::<String>("", &Metadata::PrivateKeyLocation.to_string())?,
|
||||
)?,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_private_key_info(pool: &DbConnection, info: PrivateKeyInfo) -> anyhow::Result<()> {
|
||||
pool.execute(
|
||||
pool.get_database_backend().build(
|
||||
Query::update()
|
||||
.table(Metadata::Table)
|
||||
.value(Metadata::PrivateKeyHash, Value::from(info.private_key_hash))
|
||||
.value(
|
||||
Metadata::PrivateKeyLocation,
|
||||
Value::from(serde_json::to_string(&info.private_key_location).unwrap()),
|
||||
),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::sql_migrations;
|
||||
use lldap_domain::types::{GroupId, JpegPhoto, Serialized, Uuid};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
use chrono::prelude::*;
|
||||
use sea_orm::{ConnectionTrait, Database, DbBackend, FromQueryResult};
|
||||
use tracing::error;
|
||||
|
||||
async fn get_in_memory_db() -> DbConnection {
|
||||
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
|
||||
sql_opt.max_connections(1).sqlx_logging(false);
|
||||
Database::connect(sql_opt).await.unwrap()
|
||||
}
|
||||
|
||||
fn raw_statement(sql: &str) -> sea_orm::Statement {
|
||||
sea_orm::Statement::from_string(DbBackend::Sqlite, sql.to_owned())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_init_table() {
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO users
|
||||
(user_id, email, lowercase_email, display_name, creation_date, password_hash, uuid)
|
||||
VALUES ("bôb", "böb@bob.bob", "böb@bob.bob", "Bob Bobbersön", "1970-01-01 00:00:00", "bob00", "abc")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO user_attributes
|
||||
(user_attribute_user_id, user_attribute_name, user_attribute_value)
|
||||
VALUES ("bôb", "first_name", "Bob")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
struct ShortUserDetails {
|
||||
display_name: String,
|
||||
creation_date: chrono::NaiveDateTime,
|
||||
}
|
||||
let result = ShortUserDetails::find_by_statement(raw_statement(
|
||||
r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#,
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
ShortUserDetails {
|
||||
display_name: "Bob Bobbersön".to_owned(),
|
||||
creation_date: Utc.timestamp_opt(0, 0).unwrap().naive_utc(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_already_init_table() {
|
||||
crate::logging::init_for_tests();
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migrate_tables() {
|
||||
crate::logging::init_for_tests();
|
||||
// Test that we add the column creation_date to groups and uuid to users and groups.
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"CREATE TABLE users ( user_id TEXT PRIMARY KEY, display_name TEXT, first_name TEXT NOT NULL, last_name TEXT, avatar BLOB, creation_date TEXT, email TEXT);"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO users (user_id, display_name, first_name, creation_date, email)
|
||||
VALUES ("bôb", "", "", "1970-01-01 00:00:00", "bob@bob.com")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO users (user_id, display_name, first_name, creation_date, email)
|
||||
VALUES ("john", "John Doe", "John", "1971-01-01 00:00:00", "bob2@bob.com")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO groups (display_name)
|
||||
VALUES ("lldap_admin"), ("lldap_readonly")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO groups (display_name, creation_date, uuid)
|
||||
VALUES ("test", "1970-01-01 00:00:00", "abc")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
struct SimpleUser {
|
||||
display_name: Option<String>,
|
||||
uuid: Uuid,
|
||||
}
|
||||
assert_eq!(
|
||||
SimpleUser::find_by_statement(raw_statement(
|
||||
r#"SELECT display_name, uuid FROM users ORDER BY display_name"#
|
||||
))
|
||||
.all(&sql_pool)
|
||||
.await
|
||||
.unwrap(),
|
||||
vec![
|
||||
SimpleUser {
|
||||
display_name: None,
|
||||
uuid: lldap_domain::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")
|
||||
},
|
||||
SimpleUser {
|
||||
display_name: Some("John Doe".to_owned()),
|
||||
uuid: lldap_domain::uuid!("986765a5-3f03-389e-b47b-536b2d6e1bec")
|
||||
}
|
||||
]
|
||||
);
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
struct UserAttribute {
|
||||
user_attribute_user_id: String,
|
||||
user_attribute_name: String,
|
||||
user_attribute_value: Serialized,
|
||||
}
|
||||
assert_eq!(
|
||||
UserAttribute::find_by_statement(raw_statement(
|
||||
r#"SELECT user_attribute_user_id, user_attribute_name, user_attribute_value FROM user_attributes ORDER BY user_attribute_user_id, user_attribute_value"#
|
||||
))
|
||||
.all(&sql_pool)
|
||||
.await
|
||||
.unwrap(),
|
||||
vec![
|
||||
UserAttribute {
|
||||
user_attribute_user_id: "john".to_owned(),
|
||||
user_attribute_name: "first_name".to_owned(),
|
||||
user_attribute_value: Serialized::from("John"),
|
||||
}
|
||||
]
|
||||
);
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
struct ShortGroupDetails {
|
||||
group_id: GroupId,
|
||||
display_name: String,
|
||||
}
|
||||
assert_eq!(
|
||||
ShortGroupDetails::find_by_statement(raw_statement(
|
||||
r#"SELECT group_id, display_name, creation_date FROM groups"#
|
||||
))
|
||||
.all(&sql_pool)
|
||||
.await
|
||||
.unwrap(),
|
||||
vec![
|
||||
ShortGroupDetails {
|
||||
group_id: GroupId(1),
|
||||
display_name: "lldap_admin".to_string()
|
||||
},
|
||||
ShortGroupDetails {
|
||||
group_id: GroupId(2),
|
||||
display_name: "lldap_password_manager".to_string()
|
||||
},
|
||||
ShortGroupDetails {
|
||||
group_id: GroupId(3),
|
||||
display_name: "test".to_string()
|
||||
}
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
|
||||
r#"SELECT version FROM metadata"#
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
sql_migrations::JustSchemaVersion {
|
||||
version: LAST_SCHEMA_VERSION
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migration_to_v4() {
|
||||
crate::logging::init_for_tests();
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
upgrade_to_v1(&sql_pool).await.unwrap();
|
||||
migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(3))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO users (user_id, email, display_name, first_name, creation_date, uuid)
|
||||
VALUES ("bob", "bob@bob.com", "", "", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO users (user_id, email, display_name, first_name, creation_date, uuid)
|
||||
VALUES ("bob2", "bob@bob.com", "", "", "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
error!(
|
||||
"{}",
|
||||
migrate_from_version(&sql_pool, SchemaVersion(3), SchemaVersion(4))
|
||||
.await
|
||||
.expect_err("migration should fail")
|
||||
);
|
||||
assert_eq!(
|
||||
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
|
||||
r#"SELECT version FROM metadata"#
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
sql_migrations::JustSchemaVersion {
|
||||
version: SchemaVersion(3)
|
||||
}
|
||||
);
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"UPDATE users SET email = "new@bob.com" WHERE user_id = "bob2""#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
migrate_from_version(&sql_pool, SchemaVersion(3), SchemaVersion(4))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
|
||||
r#"SELECT version FROM metadata"#
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
sql_migrations::JustSchemaVersion {
|
||||
version: SchemaVersion(4)
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migration_to_v5() {
|
||||
crate::logging::init_for_tests();
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
upgrade_to_v1(&sql_pool).await.unwrap();
|
||||
migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(4))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO users (user_id, email, creation_date, uuid)
|
||||
VALUES ("bob", "bob@bob.com", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(sea_orm::Statement::from_sql_and_values(DbBackend::Sqlite,
|
||||
r#"INSERT INTO users (user_id, email, display_name, first_name, last_name, avatar, creation_date, uuid)
|
||||
VALUES ("bob2", "bob2@bob.com", "display bob", "first bob", "last bob", $1, "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#, [JpegPhoto::for_tests().into()]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
migrate_from_version(&sql_pool, SchemaVersion(4), SchemaVersion(5))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
|
||||
r#"SELECT version FROM metadata"#
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
sql_migrations::JustSchemaVersion {
|
||||
version: SchemaVersion(5)
|
||||
}
|
||||
);
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
pub struct UserV5 {
|
||||
user_id: String,
|
||||
email: String,
|
||||
display_name: Option<String>,
|
||||
}
|
||||
assert_eq!(
|
||||
UserV5::find_by_statement(raw_statement(
|
||||
r#"SELECT user_id, email, display_name FROM users ORDER BY user_id ASC"#
|
||||
))
|
||||
.all(&sql_pool)
|
||||
.await
|
||||
.unwrap(),
|
||||
vec![
|
||||
UserV5 {
|
||||
user_id: "bob".to_owned(),
|
||||
email: "bob@bob.com".to_owned(),
|
||||
display_name: None
|
||||
},
|
||||
UserV5 {
|
||||
user_id: "bob2".to_owned(),
|
||||
email: "bob2@bob.com".to_owned(),
|
||||
display_name: Some("display bob".to_owned())
|
||||
},
|
||||
]
|
||||
);
|
||||
sql_pool
|
||||
.execute(raw_statement(r#"SELECT first_name FROM users"#))
|
||||
.await
|
||||
.unwrap_err();
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
pub struct UserAttribute {
|
||||
user_attribute_user_id: String,
|
||||
user_attribute_name: String,
|
||||
user_attribute_value: Serialized,
|
||||
}
|
||||
assert_eq!(
|
||||
UserAttribute::find_by_statement(raw_statement(r#"SELECT * FROM user_attributes ORDER BY user_attribute_user_id, user_attribute_name ASC"#))
|
||||
.all(&sql_pool)
|
||||
.await
|
||||
.unwrap(),
|
||||
vec![
|
||||
UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "avatar".to_owned(), user_attribute_value: Serialized::from(&JpegPhoto::for_tests()) },
|
||||
UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "first_name".to_owned(), user_attribute_value: Serialized::from("first bob") },
|
||||
UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "last_name".to_owned(), user_attribute_value: Serialized::from("last bob") },
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migration_to_v6() {
|
||||
crate::logging::init_for_tests();
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
upgrade_to_v1(&sql_pool).await.unwrap();
|
||||
migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(5))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO users (user_id, email, display_name, creation_date, uuid)
|
||||
VALUES ("bob", "BOb@bob.com", "", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO groups (display_name, creation_date, uuid)
|
||||
VALUES ("BestGroup", "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
migrate_from_version(&sql_pool, SchemaVersion(5), SchemaVersion(6))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
|
||||
r#"SELECT version FROM metadata"#
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
sql_migrations::JustSchemaVersion {
|
||||
version: SchemaVersion(6)
|
||||
}
|
||||
);
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
struct ShortUserDetails {
|
||||
email: String,
|
||||
lowercase_email: String,
|
||||
}
|
||||
let result = ShortUserDetails::find_by_statement(raw_statement(
|
||||
r#"SELECT email, lowercase_email FROM users WHERE user_id = "bob""#,
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
ShortUserDetails {
|
||||
email: "BOb@bob.com".to_owned(),
|
||||
lowercase_email: "bob@bob.com".to_owned(),
|
||||
}
|
||||
);
|
||||
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
|
||||
struct ShortGroupDetails {
|
||||
display_name: String,
|
||||
lowercase_display_name: String,
|
||||
}
|
||||
let result = ShortGroupDetails::find_by_statement(raw_statement(
|
||||
r#"SELECT display_name, lowercase_display_name FROM groups"#,
|
||||
))
|
||||
.one(&sql_pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
ShortGroupDetails {
|
||||
display_name: "BestGroup".to_owned(),
|
||||
lowercase_display_name: "bestgroup".to_owned(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_too_high_version() {
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"CREATE TABLE metadata ( version INTEGER);"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
sql_pool
|
||||
.execute(raw_statement(
|
||||
r#"INSERT INTO metadata (version)
|
||||
VALUES (127)"#,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(init_table(&sql_pool).await.is_err());
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user