server: extract the sql backend handler to a separate crate

This commit is contained in:
Valentin Tolmer
2025-04-04 23:43:25 -05:00
committed by nitnelave
parent ee21d83056
commit 55de3ac329
27 changed files with 276 additions and 156 deletions
+6
View File
@@ -132,6 +132,12 @@ pub mod server {
pub use super::*;
pub type ServerRegistration = opaque_ke::ServerRegistration<DefaultSuite>;
pub type ServerSetup = opaque_ke::ServerSetup<DefaultSuite>;
pub fn generate_random_private_key() -> ServerSetup {
let mut rng = rand::rngs::OsRng;
ServerSetup::new(&mut rng)
}
/// Methods to register a new user, from the server side.
pub mod registration {
pub use super::*;
-1
View File
@@ -61,7 +61,6 @@ mod tests {
use lldap_domain_handlers::handler::{GroupRequestFilter, UserRequestFilter};
use lldap_test_utils::MockTestBackendHandler;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn test_compare_user() {
-1
View File
@@ -172,7 +172,6 @@ mod tests {
use lldap_test_utils::MockTestBackendHandler;
use mockall::predicate::eq;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn test_create_user() {
-1
View File
@@ -115,7 +115,6 @@ mod tests {
use lldap_test_utils::MockTestBackendHandler;
use mockall::predicate::eq;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn test_delete_user() {
-1
View File
@@ -140,7 +140,6 @@ mod tests {
use mockall::predicate::eq;
use pretty_assertions::assert_eq;
use std::collections::HashSet;
fn setup_target_user_groups(
mock: &mut MockTestBackendHandler,
-1
View File
@@ -333,7 +333,6 @@ mod tests {
use lldap_test_utils::MockTestBackendHandler;
use mockall::predicate::eq;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn test_search_root_dse() {
+87
View File
@@ -0,0 +1,87 @@
[package]
name = "lldap_sql_backend_handler"
version = "0.1.0"
description = "SQL backend for LLDAP"
authors.workspace = true
edition.workspace = true
homepage.workspace = true
license.workspace = true
repository.workspace = true
[features]
test = []
[dependencies]
anyhow = "*"
async-trait = "0.1"
base64 = "0.21"
bincode = "1.3"
itertools = "0.10"
ldap3_proto = "0.6.0"
orion = "0.17"
serde_json = "1"
tracing = "*"
[dependencies.chrono]
features = ["serde"]
version = "*"
[dependencies.rand]
features = ["small_rng", "getrandom"]
version = "0.8"
[dependencies.sea-orm]
workspace = true
features = [
"macros",
"with-chrono",
"with-uuid",
"sqlx-all",
"runtime-actix-rustls",
]
[dependencies.secstr]
features = ["serde"]
version = "*"
[dependencies.serde]
workspace = true
[dependencies.uuid]
version = "1"
features = ["v1", "v3"]
[dependencies.lldap_access_control]
path = "../access-control"
[dependencies.lldap_auth]
path = "../auth"
features = ["opaque_server", "opaque_client", "sea_orm"]
[dependencies.lldap_domain]
path = "../domain"
[dependencies.lldap_domain_handlers]
path = "../domain-handlers"
[dependencies.lldap_domain_model]
path = "../domain-model"
[dependencies.lldap_opaque_handler]
path = "../opaque-handler"
[dev-dependencies.lldap_test_utils]
path = "../test-utils"
[dev-dependencies]
log = "*"
mockall = "0.11.4"
pretty_assertions = "1"
[dev-dependencies.tokio]
features = ["full"]
version = "1.25"
[dev-dependencies.tracing-subscriber]
version = "0.3"
features = ["env-filter", "tracing-log"]
+11
View File
@@ -0,0 +1,11 @@
pub(crate) mod logging;
pub(crate) mod sql_backend_handler;
pub(crate) mod sql_group_backend_handler;
pub(crate) mod sql_opaque_handler;
pub(crate) mod sql_schema_backend_handler;
pub(crate) mod sql_user_backend_handler;
pub use sql_backend_handler::SqlBackendHandler;
pub use sql_opaque_handler::register_password;
pub mod sql_migrations;
pub mod sql_tables;
+10
View File
@@ -0,0 +1,10 @@
#[cfg(test)]
pub fn init_for_tests() {
if let Err(e) = tracing_subscriber::FmtSubscriber::builder()
.with_max_level(tracing::Level::DEBUG)
.with_test_writer()
.try_init()
{
log::warn!("Could not set up test logging: {:#}", e);
}
}
@@ -0,0 +1,186 @@
use crate::sql_tables::DbConnection;
use async_trait::async_trait;
use lldap_auth::opaque::server::ServerSetup;
use lldap_domain_handlers::handler::BackendHandler;
#[derive(Clone)]
pub struct SqlBackendHandler {
pub(crate) opaque_setup: ServerSetup,
pub(crate) sql_pool: DbConnection,
}
impl SqlBackendHandler {
pub fn new(opaque_setup: ServerSetup, sql_pool: DbConnection) -> Self {
SqlBackendHandler {
opaque_setup,
sql_pool,
}
}
pub fn pool(&self) -> &DbConnection {
&self.sql_pool
}
}
#[async_trait]
impl BackendHandler for SqlBackendHandler {}
#[cfg(test)]
pub mod tests {
use super::*;
use crate::sql_tables::init_table;
use lldap_auth::{
opaque::{self, server::generate_random_private_key},
registration,
};
use lldap_domain::{
requests::{CreateGroupRequest, CreateUserRequest},
types::{Attribute as DomainAttribute, GroupId, UserId},
};
use lldap_domain_handlers::handler::{
GroupBackendHandler, UserBackendHandler, UserListerBackendHandler, UserRequestFilter,
};
use pretty_assertions::assert_eq;
use sea_orm::Database;
pub async fn get_in_memory_db() -> DbConnection {
crate::logging::init_for_tests();
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
sql_opt
.max_connections(1)
.sqlx_logging(true)
.sqlx_logging_level(log::LevelFilter::Debug);
Database::connect(sql_opt).await.unwrap()
}
pub async fn get_initialized_db() -> DbConnection {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool
}
pub async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) {
use lldap_opaque_handler::OpaqueHandler;
insert_user_no_password(handler, name).await;
let mut rng = rand::rngs::OsRng;
let client_registration_start =
opaque::client::registration::start_registration(pass.as_bytes(), &mut rng).unwrap();
let response = handler
.registration_start(registration::ClientRegistrationStartRequest {
username: name.into(),
registration_start_request: client_registration_start.message,
})
.await
.unwrap();
let registration_upload = opaque::client::registration::finish_registration(
client_registration_start.state,
response.registration_response,
&mut rng,
)
.unwrap();
handler
.registration_finish(registration::ClientRegistrationFinishRequest {
server_data: response.server_data,
registration_upload: registration_upload.message,
})
.await
.unwrap();
}
pub async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler
.create_user(CreateUserRequest {
user_id: UserId::new(name),
email: format!("{}@bob.bob", name).into(),
display_name: Some("display ".to_string() + name),
attributes: vec![
DomainAttribute {
name: "first_name".into(),
value: ("first ".to_string() + name).into(),
},
DomainAttribute {
name: "last_name".into(),
value: ("last ".to_string() + name).into(),
},
],
})
.await
.unwrap();
}
pub async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId {
handler
.create_group(CreateGroupRequest {
display_name: name.into(),
..Default::default()
})
.await
.unwrap()
}
pub async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
handler
.add_user_to_group(&UserId::new(user_id), group_id)
.await
.unwrap();
}
pub async fn get_user_names(
handler: &SqlBackendHandler,
filters: Option<UserRequestFilter>,
) -> Vec<String> {
handler
.list_users(filters, false)
.await
.unwrap()
.into_iter()
.map(|u| u.user.user_id.to_string())
.collect::<Vec<_>>()
}
pub struct TestFixture {
pub handler: SqlBackendHandler,
pub groups: Vec<GroupId>,
}
impl TestFixture {
pub async fn new() -> Self {
let sql_pool = get_initialized_db().await;
let handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool);
insert_user_no_password(&handler, "bob").await;
insert_user_no_password(&handler, "patrick").await;
insert_user_no_password(&handler, "John").await;
insert_user_no_password(&handler, "NoGroup").await;
let mut groups = vec![];
groups.push(insert_group(&handler, "Best Group").await);
groups.push(insert_group(&handler, "Worst Group").await);
groups.push(insert_group(&handler, "Empty Group").await);
insert_membership(&handler, groups[0], "bob").await;
insert_membership(&handler, groups[0], "patrick").await;
insert_membership(&handler, groups[1], "patrick").await;
insert_membership(&handler, groups[1], "John").await;
Self { handler, groups }
}
}
#[tokio::test]
async fn test_sql_injection() {
let sql_pool = get_initialized_db().await;
let handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool);
let user_name = UserId::new(r#"bob"e"i'o;aü"#);
insert_user_no_password(&handler, user_name.as_str()).await;
{
let users = handler
.list_users(None, false)
.await
.unwrap()
.into_iter()
.map(|u| u.user.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec![user_name.clone()]);
let user = handler.get_user_details(&user_name).await.unwrap();
assert_eq!(user.user_id, user_name);
}
}
}
@@ -0,0 +1,657 @@
use crate::sql_backend_handler::SqlBackendHandler;
use async_trait::async_trait;
use lldap_access_control::UserReadableBackendHandler;
use lldap_domain::{
requests::{CreateGroupRequest, UpdateGroupRequest},
types::{AttributeName, Group, GroupDetails, GroupId, Serialized, Uuid},
};
use lldap_domain_handlers::handler::{
GroupBackendHandler, GroupListerBackendHandler, GroupRequestFilter,
};
use lldap_domain_model::{
error::{DomainError, Result},
model::{self, GroupColumn, MembershipColumn, deserialize},
};
use sea_orm::{
ActiveModelTrait, ColumnTrait, DatabaseTransaction, EntityTrait, QueryFilter, QueryOrder,
QuerySelect, QueryTrait, Set, TransactionTrait,
sea_query::{Alias, Cond, Expr, Func, IntoCondition, OnConflict, SimpleExpr},
};
use tracing::instrument;
fn attribute_condition(name: AttributeName, value: Option<Serialized>) -> Cond {
Expr::in_subquery(
Expr::col(GroupColumn::GroupId.as_column_ref()),
model::GroupAttributes::find()
.select_only()
.column(model::GroupAttributesColumn::GroupId)
.filter(model::GroupAttributesColumn::AttributeName.eq(name))
.filter(
value
.map(|value| model::GroupAttributesColumn::Value.eq(value))
.unwrap_or_else(|| SimpleExpr::Value(true.into())),
)
.into_query(),
)
.into_condition()
}
fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond {
use GroupRequestFilter::*;
let group_table = Alias::new("groups");
match filter {
And(fs) => {
if fs.is_empty() {
SimpleExpr::Value(true.into()).into_condition()
} else {
fs.into_iter()
.fold(Cond::all(), |c, f| c.add(get_group_filter_expr(f)))
}
}
Or(fs) => {
if fs.is_empty() {
SimpleExpr::Value(false.into()).into_condition()
} else {
fs.into_iter()
.fold(Cond::any(), |c, f| c.add(get_group_filter_expr(f)))
}
}
Not(f) => get_group_filter_expr(*f).not(),
DisplayName(name) => GroupColumn::LowercaseDisplayName
.eq(name.as_str().to_lowercase())
.into_condition(),
GroupId(id) => GroupColumn::GroupId.eq(id.0).into_condition(),
Uuid(uuid) => GroupColumn::Uuid.eq(uuid.to_string()).into_condition(),
// WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user))
Member(user) => GroupColumn::GroupId
.in_subquery(
model::Membership::find()
.select_only()
.column(MembershipColumn::GroupId)
.filter(MembershipColumn::UserId.eq(user))
.into_query(),
)
.into_condition(),
DisplayNameSubString(filter) => SimpleExpr::FunctionCall(Func::lower(Expr::col((
group_table,
GroupColumn::LowercaseDisplayName,
))))
.like(filter.to_sql_filter())
.into_condition(),
AttributeEquality(name, value) => attribute_condition(name, Some(value.into())),
CustomAttributePresent(name) => attribute_condition(name, None),
}
}
#[async_trait]
impl GroupListerBackendHandler for SqlBackendHandler {
#[instrument(skip(self), level = "debug", ret, err)]
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
let filters = filters
.map(|f| {
GroupColumn::GroupId
.in_subquery(
model::Group::find()
.find_also_linked(model::memberships::GroupToUser)
.select_only()
.column(GroupColumn::GroupId)
.filter(get_group_filter_expr(f))
.into_query(),
)
.into_condition()
})
.unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition());
let results = model::Group::find()
.order_by_asc(GroupColumn::GroupId)
.find_with_related(model::Membership)
.filter(filters.clone())
.all(&self.sql_pool)
.await?;
let mut groups: Vec<_> = results
.into_iter()
.map(|(group, users)| {
let users: Vec<_> = users.into_iter().map(|u| u.user_id).collect();
Group {
users,
..group.into()
}
})
.collect();
// TODO: should be wrapped in a transaction
let schema = self.get_schema().await?;
let attributes = model::GroupAttributes::find()
.filter(
model::GroupAttributesColumn::GroupId.in_subquery(
model::Group::find()
.filter(filters)
.select_only()
.column(model::groups::Column::GroupId)
.into_query(),
),
)
.order_by_asc(model::GroupAttributesColumn::GroupId)
.order_by_asc(model::GroupAttributesColumn::AttributeName)
.all(&self.sql_pool)
.await?;
let mut attributes_iter = attributes.into_iter().peekable();
use itertools::Itertools; // For take_while_ref
for group in groups.iter_mut() {
group.attributes = attributes_iter
.take_while_ref(|u| u.group_id == group.id)
.map(|a| {
deserialize::deserialize_attribute(
a.attribute_name,
&a.value,
&schema.get_schema().group_attributes,
)
})
.collect::<Result<Vec<_>>>()?;
}
groups.sort_by(|g1, g2| g1.display_name.cmp(&g2.display_name));
Ok(groups)
}
}
#[async_trait]
impl GroupBackendHandler for SqlBackendHandler {
#[instrument(skip(self), level = "debug", ret, err)]
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupDetails> {
let mut group_details = model::Group::find_by_id(group_id)
.one(&self.sql_pool)
.await?
.map(Into::<GroupDetails>::into)
.ok_or_else(|| DomainError::EntityNotFound(format!("{:?}", group_id)))?;
let attributes = model::GroupAttributes::find()
.filter(model::GroupAttributesColumn::GroupId.eq(group_details.group_id))
.order_by_asc(model::GroupAttributesColumn::AttributeName)
.all(&self.sql_pool)
.await?;
let schema = self.get_schema().await?;
group_details.attributes = attributes
.into_iter()
.map(|a| {
deserialize::deserialize_attribute(
a.attribute_name,
&a.value,
&schema.get_schema().group_attributes,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(group_details)
}
#[instrument(skip(self), level = "debug", err, fields(group_id = ?request.group_id))]
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()> {
Ok(self
.sql_pool
.transaction::<_, (), DomainError>(|transaction| {
Box::pin(
async move { Self::update_group_with_transaction(request, transaction).await },
)
})
.await?)
}
#[instrument(skip(self), level = "debug", ret, err)]
async fn create_group(&self, request: CreateGroupRequest) -> Result<GroupId> {
let now = chrono::Utc::now().naive_utc();
let uuid = Uuid::from_name_and_date(request.display_name.as_str(), &now);
let lower_display_name = request.display_name.as_str().to_lowercase();
let new_group = model::groups::ActiveModel {
display_name: Set(request.display_name),
lowercase_display_name: Set(lower_display_name),
creation_date: Set(now),
uuid: Set(uuid),
..Default::default()
};
Ok(self
.sql_pool
.transaction::<_, GroupId, DomainError>(|transaction| {
Box::pin(async move {
let schema = Self::get_schema_with_transaction(transaction).await?;
let group_id = new_group.insert(transaction).await?.group_id;
let mut new_group_attributes = Vec::new();
for attribute in request.attributes {
if schema
.group_attributes
.get_attribute_type(&attribute.name)
.is_some()
{
new_group_attributes.push(model::group_attributes::ActiveModel {
group_id: Set(group_id),
attribute_name: Set(attribute.name),
value: Set(attribute.value.into()),
});
} else {
return Err(DomainError::InternalError(format!(
"Attribute name {} doesn't exist in the group schema,
yet was attempted to be inserted in the database",
&attribute.name
)));
}
}
if !new_group_attributes.is_empty() {
model::GroupAttributes::insert_many(new_group_attributes)
.exec(transaction)
.await?;
}
Ok(group_id)
})
})
.await?)
}
#[instrument(skip(self), level = "debug", err)]
async fn delete_group(&self, group_id: GroupId) -> Result<()> {
let res = model::Group::delete_by_id(group_id)
.exec(&self.sql_pool)
.await?;
if res.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such group: '{:?}'",
group_id
)));
}
Ok(())
}
}
impl SqlBackendHandler {
async fn update_group_with_transaction(
request: UpdateGroupRequest,
transaction: &DatabaseTransaction,
) -> Result<()> {
let lower_display_name = request
.display_name
.as_ref()
.map(|s| s.as_str().to_lowercase());
let update_group = model::groups::ActiveModel {
group_id: Set(request.group_id),
display_name: request.display_name.map(Set).unwrap_or_default(),
lowercase_display_name: lower_display_name.map(Set).unwrap_or_default(),
..Default::default()
};
update_group.update(transaction).await?;
let mut update_group_attributes = Vec::new();
let mut remove_group_attributes = Vec::new();
let schema = Self::get_schema_with_transaction(transaction).await?;
for attribute in request.insert_attributes {
if schema
.group_attributes
.get_attribute_type(&attribute.name)
.is_some()
{
update_group_attributes.push(model::group_attributes::ActiveModel {
group_id: Set(request.group_id),
attribute_name: Set(attribute.name.to_owned()),
value: Set(attribute.value.into()),
});
} else {
return Err(DomainError::InternalError(format!(
"Group attribute name {} doesn't exist in the schema, yet was attempted to be inserted in the database",
&attribute.name
)));
}
}
for attribute in request.delete_attributes {
if schema
.group_attributes
.get_attribute_type(&attribute)
.is_some()
{
remove_group_attributes.push(attribute);
} else {
return Err(DomainError::InternalError(format!(
"Group attribute name {} doesn't exist in the schema, yet was attempted to be removed from the database",
attribute
)));
}
}
if !remove_group_attributes.is_empty() {
model::GroupAttributes::delete_many()
.filter(model::GroupAttributesColumn::GroupId.eq(request.group_id))
.filter(model::GroupAttributesColumn::AttributeName.is_in(remove_group_attributes))
.exec(transaction)
.await?;
}
if !update_group_attributes.is_empty() {
model::GroupAttributes::insert_many(update_group_attributes)
.on_conflict(
OnConflict::columns([
model::GroupAttributesColumn::GroupId,
model::GroupAttributesColumn::AttributeName,
])
.update_column(model::GroupAttributesColumn::Value)
.to_owned(),
)
.exec(transaction)
.await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sql_backend_handler::tests::*;
use lldap_domain::{
requests::CreateAttributeRequest,
types::{Attribute, AttributeType, GroupName, UserId},
};
use lldap_domain_handlers::handler::{SchemaBackendHandler, SubStringFilter};
use pretty_assertions::assert_eq;
async fn get_group_ids(
handler: &SqlBackendHandler,
filters: Option<GroupRequestFilter>,
) -> Vec<GroupId> {
handler
.list_groups(filters)
.await
.unwrap()
.into_iter()
.map(|g| g.id)
.collect::<Vec<_>>()
}
async fn get_group_names(
handler: &SqlBackendHandler,
filters: Option<GroupRequestFilter>,
) -> Vec<GroupName> {
handler
.list_groups(filters)
.await
.unwrap()
.into_iter()
.map(|g| g.display_name)
.collect::<Vec<_>>()
}
#[tokio::test]
async fn test_list_groups_no_filter() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_names(&fixture.handler, None).await,
vec![
"Best Group".into(),
"Empty Group".into(),
"Worst Group".into()
]
);
}
#[tokio::test]
async fn test_list_groups_simple_filter() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_names(
&fixture.handler,
Some(GroupRequestFilter::Or(vec![
GroupRequestFilter::DisplayName("Empty Group".into()),
GroupRequestFilter::Member(UserId::new("bob")),
]))
)
.await,
vec!["Best Group".into(), "Empty Group".into()]
);
}
#[tokio::test]
async fn test_list_groups_case_insensitive_filter() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_names(
&fixture.handler,
Some(GroupRequestFilter::DisplayName("eMpTy gRoup".into()),)
)
.await,
vec!["Empty Group".into()]
);
}
#[tokio::test]
async fn test_list_groups_negation() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(
&fixture.handler,
Some(GroupRequestFilter::And(vec![
GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName(
"value".into()
))),
GroupRequestFilter::GroupId(fixture.groups[0]),
]))
)
.await,
vec![fixture.groups[0]]
);
}
#[tokio::test]
async fn test_list_groups_substring_filter() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(
&fixture.handler,
Some(GroupRequestFilter::DisplayNameSubString(SubStringFilter {
initial: Some("be".to_owned()),
any: vec!["sT".to_owned()],
final_: Some("P".to_owned()),
})),
)
.await,
// Best group
vec![fixture.groups[0]]
);
}
#[tokio::test]
async fn test_list_groups_other_filter() {
let fixture = TestFixture::new().await;
fixture
.handler
.add_group_attribute(CreateAttributeRequest {
name: "gid".into(),
attribute_type: AttributeType::Integer,
is_list: false,
is_visible: true,
is_editable: true,
})
.await
.unwrap();
fixture
.handler
.update_group(UpdateGroupRequest {
group_id: fixture.groups[0],
display_name: None,
delete_attributes: Vec::new(),
insert_attributes: vec![Attribute {
name: "gid".into(),
value: 512.into(),
}],
})
.await
.unwrap();
assert_eq!(
get_group_ids(
&fixture.handler,
Some(GroupRequestFilter::AttributeEquality(
AttributeName::from("gid"),
512.into(),
)),
)
.await,
vec![fixture.groups[0]]
);
}
#[tokio::test]
async fn test_get_group_details() {
let fixture = TestFixture::new().await;
let details = fixture
.handler
.get_group_details(fixture.groups[0])
.await
.unwrap();
assert_eq!(details.group_id, fixture.groups[0]);
assert_eq!(details.display_name, "Best Group".into());
assert_eq!(
get_group_ids(
&fixture.handler,
Some(GroupRequestFilter::Uuid(details.uuid))
)
.await,
vec![fixture.groups[0]]
);
}
#[tokio::test]
async fn test_update_group() {
let fixture = TestFixture::new().await;
fixture
.handler
.update_group(UpdateGroupRequest {
group_id: fixture.groups[0],
display_name: Some("Awesomest Group".into()),
delete_attributes: Vec::new(),
insert_attributes: Vec::new(),
})
.await
.unwrap();
let details = fixture
.handler
.get_group_details(fixture.groups[0])
.await
.unwrap();
assert_eq!(details.display_name, "Awesomest Group".into());
}
#[tokio::test]
async fn test_delete_group() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(&fixture.handler, None).await,
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]]
);
fixture
.handler
.delete_group(fixture.groups[0])
.await
.unwrap();
assert_eq!(
get_group_ids(&fixture.handler, None).await,
vec![fixture.groups[2], fixture.groups[1]]
);
}
#[tokio::test]
async fn test_create_group() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_ids(&fixture.handler, None).await,
vec![fixture.groups[0], fixture.groups[2], fixture.groups[1]]
);
fixture
.handler
.add_group_attribute(CreateAttributeRequest {
name: "new_attribute".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
is_editable: true,
})
.await
.unwrap();
let new_group_id = fixture
.handler
.create_group(CreateGroupRequest {
display_name: "New Group".into(),
attributes: vec![Attribute {
name: "new_attribute".into(),
value: "value".to_string().into(),
}],
})
.await
.unwrap();
let group_details = fixture
.handler
.get_group_details(new_group_id)
.await
.unwrap();
assert_eq!(group_details.display_name, "New Group".into());
assert_eq!(
group_details.attributes,
vec![Attribute {
name: "new_attribute".into(),
value: "value".to_string().into(),
}]
);
}
#[tokio::test]
async fn test_set_group_attributes() {
let fixture = TestFixture::new().await;
fixture
.handler
.add_group_attribute(CreateAttributeRequest {
name: "new_attribute".into(),
attribute_type: AttributeType::Integer,
is_list: false,
is_visible: true,
is_editable: true,
})
.await
.unwrap();
let group_id = fixture.groups[0];
let attributes = vec![Attribute {
name: "new_attribute".into(),
value: 42i64.into(),
}];
fixture
.handler
.update_group(UpdateGroupRequest {
group_id,
display_name: None,
delete_attributes: Vec::new(),
insert_attributes: attributes.clone(),
})
.await
.unwrap();
let details = fixture.handler.get_group_details(group_id).await.unwrap();
assert_eq!(details.attributes, attributes);
fixture
.handler
.update_group(UpdateGroupRequest {
group_id,
display_name: None,
delete_attributes: vec!["new_attribute".into()],
insert_attributes: Vec::new(),
})
.await
.unwrap();
let details = fixture.handler.get_group_details(group_id).await.unwrap();
assert_eq!(details.attributes, Vec::new());
}
#[tokio::test]
async fn test_create_group_duplicate_name() {
let fixture = TestFixture::new().await;
fixture
.handler
.create_group(CreateGroupRequest {
display_name: "New Group".into(),
..Default::default()
})
.await
.unwrap();
fixture
.handler
.create_group(CreateGroupRequest {
display_name: "neW group".into(),
..Default::default()
})
.await
.unwrap_err();
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,343 @@
use crate::SqlBackendHandler;
use async_trait::async_trait;
use base64::Engine;
use lldap_auth::opaque;
use lldap_domain::types::UserId;
use lldap_domain_handlers::handler::{BindRequest, LoginHandler};
use lldap_domain_model::{
error::{DomainError, Result},
model::{self, UserColumn},
};
use lldap_opaque_handler::{OpaqueHandler, login, registration};
use sea_orm::{ActiveModelTrait, ActiveValue, EntityTrait, QuerySelect};
use secstr::SecUtf8;
use tracing::{debug, info, instrument, warn};
type SqlOpaqueHandler = SqlBackendHandler;
#[instrument(skip_all, level = "debug", err, fields(username = %username.as_str()))]
fn passwords_match(
password_file_bytes: &[u8],
clear_password: &str,
opaque_setup: &opaque::server::ServerSetup,
username: &UserId,
) -> Result<()> {
use opaque::{client, server};
let mut rng = rand::rngs::OsRng;
let client_login_start_result = client::login::start_login(clear_password, &mut rng)?;
let password_file = server::ServerRegistration::deserialize(password_file_bytes)
.map_err(opaque::AuthenticationError::ProtocolError)?;
let server_login_start_result = server::login::start_login(
&mut rng,
opaque_setup,
Some(password_file),
client_login_start_result.message,
username,
)?;
client::login::finish_login(
client_login_start_result.state,
server_login_start_result.message,
)?;
Ok(())
}
impl SqlBackendHandler {
fn get_orion_secret_key(&self) -> Result<orion::aead::SecretKey> {
Ok(orion::aead::SecretKey::from_slice(
self.opaque_setup.keypair().private(),
)?)
}
#[instrument(skip(self), level = "debug", err)]
async fn get_password_file_for_user(&self, user_id: UserId) -> Result<Option<Vec<u8>>> {
// Fetch the previously registered password file from the DB.
Ok(model::User::find_by_id(user_id)
.select_only()
.column(UserColumn::PasswordHash)
.into_tuple::<(Option<Vec<u8>>,)>()
.one(&self.sql_pool)
.await?
.and_then(|u| u.0))
}
}
#[async_trait]
impl LoginHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug", err)]
async fn bind(&self, request: BindRequest) -> Result<()> {
if let Some(password_hash) = self
.get_password_file_for_user(request.name.clone())
.await?
{
info!(r#"Login attempt for "{}""#, &request.name);
if passwords_match(
&password_hash,
&request.password,
&self.opaque_setup,
&request.name,
)
.is_ok()
{
return Ok(());
}
} else {
debug!(
r#"User "{}" doesn't exist or has no password"#,
&request.name
);
}
Err(DomainError::AuthenticationError(format!(
r#"for user "{}""#,
request.name
)))
}
}
#[async_trait]
impl OpaqueHandler for SqlOpaqueHandler {
#[instrument(skip_all, level = "debug", err)]
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse> {
let user_id = request.username;
info!(r#"OPAQUE login attempt for "{}""#, &user_id);
let maybe_password_file = self
.get_password_file_for_user(user_id.clone())
.await?
.map(|bytes| {
opaque::server::ServerRegistration::deserialize(&bytes).map_err(|_| {
DomainError::InternalError(format!("Corrupted password file for {}", &user_id))
})
})
.transpose()?;
let mut rng = rand::rngs::OsRng;
// Get the CredentialResponse for the user, or a dummy one if no user/no password.
let start_response = opaque::server::login::start_login(
&mut rng,
&self.opaque_setup,
maybe_password_file,
request.login_start_request,
&user_id,
)?;
let secret_key = self.get_orion_secret_key()?;
let server_data = login::ServerData {
username: user_id,
server_login: start_response.state,
};
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
Ok(login::ServerLoginStartResponse {
server_data: base64::engine::general_purpose::STANDARD.encode(encrypted_state),
credential_response: start_response.message,
})
}
#[instrument(skip_all, level = "debug", err)]
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId> {
let secret_key = self.get_orion_secret_key()?;
let login::ServerData {
username,
server_login,
} = bincode::deserialize(&orion::aead::open(
&secret_key,
&base64::engine::general_purpose::STANDARD.decode(&request.server_data)?,
)?)?;
// Finish the login: this makes sure the client data is correct, and gives a session key we
// don't need.
match opaque::server::login::finish_login(server_login, request.credential_finalization) {
Ok(session) => {
info!(r#"OPAQUE login successful for "{}""#, &username);
let _ = session.session_key;
}
Err(e) => {
warn!(r#"OPAQUE login attempt failed for "{}""#, &username);
return Err(e.into());
}
};
Ok(username)
}
#[instrument(skip_all, level = "debug", err)]
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
) -> Result<registration::ServerRegistrationStartResponse> {
// Generate the server-side key and derive the data to send back.
let start_response = opaque::server::registration::start_registration(
&self.opaque_setup,
request.registration_start_request,
&request.username,
)?;
let secret_key = self.get_orion_secret_key()?;
let server_data = registration::ServerData {
username: request.username,
};
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
Ok(registration::ServerRegistrationStartResponse {
server_data: base64::engine::general_purpose::STANDARD.encode(encrypted_state),
registration_response: start_response.message,
})
}
#[instrument(skip_all, level = "debug", err)]
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
) -> Result<()> {
let secret_key = self.get_orion_secret_key()?;
let registration::ServerData { username } = bincode::deserialize(&orion::aead::open(
&secret_key,
&base64::engine::general_purpose::STANDARD.decode(&request.server_data)?,
)?)?;
let password_file =
opaque::server::registration::get_password_file(request.registration_upload);
// Set the user password to the new password.
let user_update = model::users::ActiveModel {
user_id: ActiveValue::Set(username.clone()),
password_hash: ActiveValue::Set(Some(password_file.serialize())),
..Default::default()
};
user_update.update(&self.sql_pool).await?;
info!(r#"Successfully (re)set password for "{}""#, &username);
Ok(())
}
}
/// Convenience function to set a user's password.
#[instrument(skip_all, level = "debug", err, fields(username = %username.as_str()))]
pub async fn register_password(
opaque_handler: &SqlOpaqueHandler,
username: UserId,
password: &SecUtf8,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use registration::*;
let registration_start =
opaque::client::registration::start_registration(password.unsecure().as_bytes(), &mut rng)?;
let start_response = opaque_handler
.registration_start(ClientRegistrationStartRequest {
username,
registration_start_request: registration_start.message,
})
.await?;
let registration_finish = opaque::client::registration::finish_registration(
registration_start.state,
start_response.registration_response,
&mut rng,
)?;
opaque_handler
.registration_finish(ClientRegistrationFinishRequest {
server_data: start_response.server_data,
registration_upload: registration_finish.message,
})
.await
}
#[cfg(test)]
mod tests {
use self::opaque::server::generate_random_private_key;
use super::*;
use crate::sql_backend_handler::tests::{
get_initialized_db, insert_user, insert_user_no_password,
};
async fn attempt_login(
opaque_handler: &SqlOpaqueHandler,
username: &str,
password: &str,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use login::*;
let login_start = opaque::client::login::start_login(password, &mut rng)?;
let start_response = opaque_handler
.login_start(ClientLoginStartRequest {
username: UserId::new(username),
login_start_request: login_start.message,
})
.await?;
let login_finish = opaque::client::login::finish_login(
login_start.state,
start_response.credential_response,
)?;
opaque_handler
.login_finish(ClientLoginFinishRequest {
server_data: start_response.server_data,
credential_finalization: login_finish.message,
})
.await?;
Ok(())
}
#[tokio::test]
async fn test_opaque_flow() -> Result<()> {
let sql_pool = get_initialized_db().await;
crate::logging::init_for_tests();
let backend_handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool);
insert_user_no_password(&backend_handler, "bob").await;
insert_user_no_password(&backend_handler, "john").await;
attempt_login(&backend_handler, "bob", "bob00")
.await
.unwrap_err();
register_password(
&backend_handler,
UserId::new("bob"),
&secstr::SecUtf8::from("bob00"),
)
.await?;
attempt_login(&backend_handler, "bob", "wrong_password")
.await
.unwrap_err();
attempt_login(&backend_handler, "bob", "bob00").await?;
Ok(())
}
#[tokio::test]
async fn test_bind_user() {
let sql_pool = get_initialized_db().await;
let handler = SqlOpaqueHandler::new(generate_random_private_key(), sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
handler
.bind(BindRequest {
name: UserId::new("bob"),
password: "bob00".to_string(),
})
.await
.unwrap();
handler
.bind(BindRequest {
name: UserId::new("andrew"),
password: "bob00".to_string(),
})
.await
.unwrap_err();
handler
.bind(BindRequest {
name: UserId::new("bob"),
password: "wrong_password".to_string(),
})
.await
.unwrap_err();
}
#[tokio::test]
async fn test_user_no_password() {
let sql_pool = get_initialized_db().await;
let handler = SqlBackendHandler::new(generate_random_private_key(), sql_pool.clone());
insert_user_no_password(&handler, "bob").await;
handler
.bind(BindRequest {
name: UserId::new("bob"),
password: "bob00".to_string(),
})
.await
.unwrap_err();
}
}
@@ -0,0 +1,417 @@
use crate::sql_backend_handler::SqlBackendHandler;
use async_trait::async_trait;
use lldap_domain::{
requests::CreateAttributeRequest,
schema::{AttributeList, AttributeSchema, Schema},
types::{AttributeName, LdapObjectClass},
};
use lldap_domain_handlers::handler::{ReadSchemaBackendHandler, SchemaBackendHandler};
use lldap_domain_model::{
error::{DomainError, Result},
model,
};
use sea_orm::{
ActiveModelTrait, DatabaseTransaction, EntityTrait, QueryOrder, Set, TransactionTrait,
};
#[async_trait]
impl ReadSchemaBackendHandler for SqlBackendHandler {
async fn get_schema(&self) -> Result<Schema> {
Ok(self
.sql_pool
.transaction::<_, Schema, DomainError>(|transaction| {
Box::pin(async move { Self::get_schema_with_transaction(transaction).await })
})
.await?)
}
}
#[async_trait]
impl SchemaBackendHandler for SqlBackendHandler {
async fn add_user_attribute(&self, request: CreateAttributeRequest) -> Result<()> {
let new_attribute = model::user_attribute_schema::ActiveModel {
attribute_name: Set(request.name),
attribute_type: Set(request.attribute_type),
is_list: Set(request.is_list),
is_user_visible: Set(request.is_visible),
is_user_editable: Set(request.is_editable),
is_hardcoded: Set(false),
};
new_attribute.insert(&self.sql_pool).await?;
Ok(())
}
async fn add_group_attribute(&self, request: CreateAttributeRequest) -> Result<()> {
let new_attribute = model::group_attribute_schema::ActiveModel {
attribute_name: Set(request.name),
attribute_type: Set(request.attribute_type),
is_list: Set(request.is_list),
is_group_visible: Set(request.is_visible),
is_group_editable: Set(request.is_editable),
is_hardcoded: Set(false),
};
new_attribute.insert(&self.sql_pool).await?;
Ok(())
}
async fn delete_user_attribute(&self, name: &AttributeName) -> Result<()> {
model::UserAttributeSchema::delete_by_id(name.clone())
.exec(&self.sql_pool)
.await?;
Ok(())
}
async fn delete_group_attribute(&self, name: &AttributeName) -> Result<()> {
model::GroupAttributeSchema::delete_by_id(name.clone())
.exec(&self.sql_pool)
.await?;
Ok(())
}
async fn add_user_object_class(&self, name: &LdapObjectClass) -> Result<()> {
let mut name_key = name.to_string();
name_key.make_ascii_lowercase();
model::user_object_classes::ActiveModel {
lower_object_class: Set(name_key),
object_class: Set(name.clone()),
}
.insert(&self.sql_pool)
.await?;
Ok(())
}
async fn add_group_object_class(&self, name: &LdapObjectClass) -> Result<()> {
let mut name_key = name.to_string();
name_key.make_ascii_lowercase();
model::group_object_classes::ActiveModel {
lower_object_class: Set(name_key),
object_class: Set(name.clone()),
}
.insert(&self.sql_pool)
.await?;
Ok(())
}
async fn delete_user_object_class(&self, name: &LdapObjectClass) -> Result<()> {
model::UserObjectClasses::delete_by_id(name.as_str().to_ascii_lowercase())
.exec(&self.sql_pool)
.await?;
Ok(())
}
async fn delete_group_object_class(&self, name: &LdapObjectClass) -> Result<()> {
model::GroupObjectClasses::delete_by_id(name.as_str().to_ascii_lowercase())
.exec(&self.sql_pool)
.await?;
Ok(())
}
}
impl SqlBackendHandler {
pub(crate) async fn get_schema_with_transaction(
transaction: &DatabaseTransaction,
) -> Result<Schema> {
Ok(Schema {
user_attributes: AttributeList {
attributes: Self::get_user_attributes(transaction).await?,
},
group_attributes: AttributeList {
attributes: Self::get_group_attributes(transaction).await?,
},
extra_user_object_classes: Self::get_user_object_classes(transaction).await?,
extra_group_object_classes: Self::get_group_object_classes(transaction).await?,
})
}
async fn get_user_attributes(
transaction: &DatabaseTransaction,
) -> Result<Vec<AttributeSchema>> {
Ok(model::UserAttributeSchema::find()
.order_by_asc(model::UserAttributeSchemaColumn::AttributeName)
.all(transaction)
.await?
.into_iter()
.map(|m| m.into())
.collect())
}
async fn get_group_attributes(
transaction: &DatabaseTransaction,
) -> Result<Vec<AttributeSchema>> {
Ok(model::GroupAttributeSchema::find()
.order_by_asc(model::GroupAttributeSchemaColumn::AttributeName)
.all(transaction)
.await?
.into_iter()
.map(|m| m.into())
.collect())
}
async fn get_user_object_classes(
transaction: &DatabaseTransaction,
) -> Result<Vec<LdapObjectClass>> {
Ok(model::UserObjectClasses::find()
.order_by_asc(model::UserObjectClassesColumn::ObjectClass)
.all(transaction)
.await?
.into_iter()
.map(Into::into)
.collect())
}
async fn get_group_object_classes(
transaction: &DatabaseTransaction,
) -> Result<Vec<LdapObjectClass>> {
Ok(model::GroupObjectClasses::find()
.order_by_asc(model::GroupObjectClassesColumn::ObjectClass)
.all(transaction)
.await?
.into_iter()
.map(Into::into)
.collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sql_backend_handler::tests::*;
use lldap_domain::requests::UpdateUserRequest;
use lldap_domain::schema::AttributeList;
use lldap_domain::types::{Attribute, AttributeType};
use lldap_domain_handlers::handler::{UserBackendHandler, UserRequestFilter};
use pretty_assertions::assert_eq;
#[tokio::test]
async fn test_default_schema() {
let fixture = TestFixture::new().await;
assert_eq!(
fixture.handler.get_schema().await.unwrap(),
Schema {
user_attributes: AttributeList {
attributes: vec![
AttributeSchema {
name: "avatar".into(),
attribute_type: AttributeType::JpegPhoto,
is_list: false,
is_visible: true,
is_editable: true,
is_hardcoded: true,
is_readonly: false,
},
AttributeSchema {
name: "first_name".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
is_editable: true,
is_hardcoded: true,
is_readonly: false,
},
AttributeSchema {
name: "last_name".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
is_editable: true,
is_hardcoded: true,
is_readonly: false,
}
]
},
group_attributes: AttributeList {
attributes: Vec::new()
},
extra_user_object_classes: Vec::new(),
extra_group_object_classes: Vec::new(),
}
);
}
#[tokio::test]
async fn test_user_attribute_add_and_delete() {
let fixture = TestFixture::new().await;
let new_attribute = CreateAttributeRequest {
name: "new_attribute".into(),
attribute_type: AttributeType::Integer,
is_list: true,
is_visible: false,
is_editable: false,
};
fixture
.handler
.add_user_attribute(new_attribute)
.await
.unwrap();
let expected_value = AttributeSchema {
name: "new_attribute".into(),
attribute_type: AttributeType::Integer,
is_list: true,
is_visible: false,
is_editable: false,
is_hardcoded: false,
is_readonly: false,
};
assert!(
fixture
.handler
.get_schema()
.await
.unwrap()
.user_attributes
.attributes
.contains(&expected_value)
);
fixture
.handler
.delete_user_attribute(&"new_attribute".into())
.await
.unwrap();
assert!(
!fixture
.handler
.get_schema()
.await
.unwrap()
.user_attributes
.attributes
.contains(&expected_value)
);
}
#[tokio::test]
async fn test_user_attribute_present_filter() {
let fixture = TestFixture::new().await;
let new_attribute = CreateAttributeRequest {
name: "new_attribute".into(),
attribute_type: AttributeType::Integer,
is_list: true,
is_visible: false,
is_editable: false,
};
fixture
.handler
.add_user_attribute(new_attribute)
.await
.unwrap();
fixture
.handler
.update_user(UpdateUserRequest {
user_id: "bob".into(),
insert_attributes: vec![Attribute {
name: "new_attribute".into(),
value: vec![3].into(),
}],
..Default::default()
})
.await
.unwrap();
let users = get_user_names(
&fixture.handler,
Some(UserRequestFilter::CustomAttributePresent(
"new_attribute".into(),
)),
)
.await;
assert_eq!(users, vec!["bob"]);
}
#[tokio::test]
async fn test_group_attribute_add_and_delete() {
let fixture = TestFixture::new().await;
let new_attribute = CreateAttributeRequest {
name: "NeW_aTTribute".into(),
attribute_type: AttributeType::JpegPhoto,
is_list: false,
is_visible: true,
is_editable: false,
};
fixture
.handler
.add_group_attribute(new_attribute)
.await
.unwrap();
let expected_value = AttributeSchema {
name: "new_attribute".into(),
attribute_type: AttributeType::JpegPhoto,
is_list: false,
is_visible: true,
is_editable: false,
is_hardcoded: false,
is_readonly: false,
};
assert!(
fixture
.handler
.get_schema()
.await
.unwrap()
.group_attributes
.attributes
.contains(&expected_value)
);
fixture
.handler
.delete_group_attribute(&"new_attriBUte".into())
.await
.unwrap();
assert!(
!fixture
.handler
.get_schema()
.await
.unwrap()
.group_attributes
.attributes
.contains(&expected_value)
);
}
#[tokio::test]
async fn test_user_object_class_add_and_delete() {
let fixture = TestFixture::new().await;
let new_object_class = LdapObjectClass::new("newObjectClass");
fixture
.handler
.add_user_object_class(&new_object_class)
.await
.unwrap();
assert_eq!(
fixture
.handler
.get_schema()
.await
.unwrap()
.extra_user_object_classes,
vec![new_object_class.clone()]
);
fixture
.handler
.add_user_object_class(&LdapObjectClass::new("newobjEctclass"))
.await
.expect_err("Should not be able to add the same object class twice");
assert_eq!(
fixture
.handler
.get_schema()
.await
.unwrap()
.extra_user_object_classes,
vec![new_object_class.clone()]
);
fixture
.handler
.delete_user_object_class(&new_object_class)
.await
.unwrap();
assert!(
fixture
.handler
.get_schema()
.await
.unwrap()
.extra_user_object_classes
.is_empty()
);
}
}
@@ -0,0 +1,566 @@
use crate::sql_migrations::{Metadata, get_schema_version, migrate_from_version, upgrade_to_v1};
use sea_orm::{
ConnectionTrait, DeriveValueType, Iden, QueryResult, TryGetable, Value, sea_query::Query,
};
use serde::{Deserialize, Serialize};
pub type DbConnection = sea_orm::DatabaseConnection;
#[derive(Copy, PartialEq, Eq, Debug, Clone, PartialOrd, Ord, DeriveValueType)]
pub struct SchemaVersion(pub i16);
pub const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(10);
#[derive(Copy, PartialEq, Eq, Debug, Clone, PartialOrd, Ord)]
pub struct PrivateKeyHash(pub [u8; 32]);
impl TryGetable for PrivateKeyHash {
fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, sea_orm::TryGetError> {
let index = format!("{pre}{col}");
Self::try_get_by(res, index.as_str())
}
fn try_get_by_index(res: &QueryResult, index: usize) -> Result<Self, sea_orm::TryGetError> {
Self::try_get_by(res, index)
}
fn try_get_by<I: sea_orm::ColIdx>(
res: &QueryResult,
index: I,
) -> Result<Self, sea_orm::TryGetError> {
Ok(PrivateKeyHash(
std::convert::TryInto::<[u8; 32]>::try_into(res.try_get_by::<Vec<u8>, I>(index)?)
.unwrap(),
))
}
}
impl From<PrivateKeyHash> for Value {
fn from(val: PrivateKeyHash) -> Self {
Self::from(val.0.to_vec())
}
}
pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> {
let version = {
if let Some(version) = get_schema_version(pool).await {
version
} else {
upgrade_to_v1(pool).await?;
SchemaVersion(1)
}
};
migrate_from_version(pool, version, LAST_SCHEMA_VERSION).await?;
Ok(())
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum ConfigLocation {
ConfigFile(String),
EnvironmentVariable(String),
CommandLine,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum PrivateKeyLocation {
KeySeed(ConfigLocation),
KeyFile(ConfigLocation, std::ffi::OsString),
Default,
Tests,
}
#[derive(Debug)]
pub struct PrivateKeyInfo {
pub private_key_hash: PrivateKeyHash,
pub private_key_location: PrivateKeyLocation,
}
pub async fn get_private_key_info(pool: &DbConnection) -> anyhow::Result<Option<PrivateKeyInfo>> {
let result = pool
.query_one(
pool.get_database_backend().build(
Query::select()
.column(Metadata::PrivateKeyHash)
.column(Metadata::PrivateKeyLocation)
.from(Metadata::Table),
),
)
.await?;
let result = match result {
None => return Ok(None),
Some(r) => r,
};
if let Ok(hash) = result.try_get("", &Metadata::PrivateKeyHash.to_string()) {
Ok(Some(PrivateKeyInfo {
private_key_hash: hash,
private_key_location: serde_json::from_str(
&result.try_get::<String>("", &Metadata::PrivateKeyLocation.to_string())?,
)?,
}))
} else {
Ok(None)
}
}
pub async fn set_private_key_info(pool: &DbConnection, info: PrivateKeyInfo) -> anyhow::Result<()> {
pool.execute(
pool.get_database_backend().build(
Query::update()
.table(Metadata::Table)
.value(Metadata::PrivateKeyHash, Value::from(info.private_key_hash))
.value(
Metadata::PrivateKeyLocation,
Value::from(serde_json::to_string(&info.private_key_location).unwrap()),
),
),
)
.await?;
Ok(())
}
#[cfg(test)]
mod tests {
use crate::sql_migrations;
use lldap_domain::types::{GroupId, JpegPhoto, Serialized, Uuid};
use pretty_assertions::assert_eq;
use super::*;
use chrono::prelude::*;
use sea_orm::{ConnectionTrait, Database, DbBackend, FromQueryResult};
use tracing::error;
async fn get_in_memory_db() -> DbConnection {
let mut sql_opt = sea_orm::ConnectOptions::new("sqlite::memory:".to_owned());
sql_opt.max_connections(1).sqlx_logging(false);
Database::connect(sql_opt).await.unwrap()
}
fn raw_statement(sql: &str) -> sea_orm::Statement {
sea_orm::Statement::from_string(DbBackend::Sqlite, sql.to_owned())
}
#[tokio::test]
async fn test_init_table() {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users
(user_id, email, lowercase_email, display_name, creation_date, password_hash, uuid)
VALUES ("bôb", "böb@bob.bob", "böb@bob.bob", "Bob Bobbersön", "1970-01-01 00:00:00", "bob00", "abc")"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO user_attributes
(user_attribute_user_id, user_attribute_name, user_attribute_value)
VALUES ("bôb", "first_name", "Bob")"#,
))
.await
.unwrap();
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortUserDetails {
display_name: String,
creation_date: chrono::NaiveDateTime,
}
let result = ShortUserDetails::find_by_statement(raw_statement(
r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#,
))
.one(&sql_pool)
.await
.unwrap()
.unwrap();
assert_eq!(
result,
ShortUserDetails {
display_name: "Bob Bobbersön".to_owned(),
creation_date: Utc.timestamp_opt(0, 0).unwrap().naive_utc(),
}
);
}
#[tokio::test]
async fn test_already_init_table() {
crate::logging::init_for_tests();
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
init_table(&sql_pool).await.unwrap();
}
#[tokio::test]
async fn test_migrate_tables() {
crate::logging::init_for_tests();
// Test that we add the column creation_date to groups and uuid to users and groups.
let sql_pool = get_in_memory_db().await;
sql_pool
.execute(raw_statement(
r#"CREATE TABLE users ( user_id TEXT PRIMARY KEY, display_name TEXT, first_name TEXT NOT NULL, last_name TEXT, avatar BLOB, creation_date TEXT, email TEXT);"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, display_name, first_name, creation_date, email)
VALUES ("bôb", "", "", "1970-01-01 00:00:00", "bob@bob.com")"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, display_name, first_name, creation_date, email)
VALUES ("john", "John Doe", "John", "1971-01-01 00:00:00", "bob2@bob.com")"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"CREATE TABLE groups ( group_id INTEGER PRIMARY KEY, display_name TEXT );"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name)
VALUES ("lldap_admin"), ("lldap_readonly")"#,
))
.await
.unwrap();
init_table(&sql_pool).await.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name, creation_date, uuid)
VALUES ("test", "1970-01-01 00:00:00", "abc")"#,
))
.await
.unwrap();
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct SimpleUser {
display_name: Option<String>,
uuid: Uuid,
}
assert_eq!(
SimpleUser::find_by_statement(raw_statement(
r#"SELECT display_name, uuid FROM users ORDER BY display_name"#
))
.all(&sql_pool)
.await
.unwrap(),
vec![
SimpleUser {
display_name: None,
uuid: lldap_domain::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04")
},
SimpleUser {
display_name: Some("John Doe".to_owned()),
uuid: lldap_domain::uuid!("986765a5-3f03-389e-b47b-536b2d6e1bec")
}
]
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct UserAttribute {
user_attribute_user_id: String,
user_attribute_name: String,
user_attribute_value: Serialized,
}
assert_eq!(
UserAttribute::find_by_statement(raw_statement(
r#"SELECT user_attribute_user_id, user_attribute_name, user_attribute_value FROM user_attributes ORDER BY user_attribute_user_id, user_attribute_value"#
))
.all(&sql_pool)
.await
.unwrap(),
vec![
UserAttribute {
user_attribute_user_id: "john".to_owned(),
user_attribute_name: "first_name".to_owned(),
user_attribute_value: Serialized::from("John"),
}
]
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortGroupDetails {
group_id: GroupId,
display_name: String,
}
assert_eq!(
ShortGroupDetails::find_by_statement(raw_statement(
r#"SELECT group_id, display_name, creation_date FROM groups"#
))
.all(&sql_pool)
.await
.unwrap(),
vec![
ShortGroupDetails {
group_id: GroupId(1),
display_name: "lldap_admin".to_string()
},
ShortGroupDetails {
group_id: GroupId(2),
display_name: "lldap_password_manager".to_string()
},
ShortGroupDetails {
group_id: GroupId(3),
display_name: "test".to_string()
}
]
);
assert_eq!(
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
r#"SELECT version FROM metadata"#
))
.one(&sql_pool)
.await
.unwrap()
.unwrap(),
sql_migrations::JustSchemaVersion {
version: LAST_SCHEMA_VERSION
}
);
}
#[tokio::test]
async fn test_migration_to_v4() {
crate::logging::init_for_tests();
let sql_pool = get_in_memory_db().await;
upgrade_to_v1(&sql_pool).await.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(3))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, email, display_name, first_name, creation_date, uuid)
VALUES ("bob", "bob@bob.com", "", "", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, email, display_name, first_name, creation_date, uuid)
VALUES ("bob2", "bob@bob.com", "", "", "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#,
))
.await
.unwrap();
error!(
"{}",
migrate_from_version(&sql_pool, SchemaVersion(3), SchemaVersion(4))
.await
.expect_err("migration should fail")
);
assert_eq!(
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
r#"SELECT version FROM metadata"#
))
.one(&sql_pool)
.await
.unwrap()
.unwrap(),
sql_migrations::JustSchemaVersion {
version: SchemaVersion(3)
}
);
sql_pool
.execute(raw_statement(
r#"UPDATE users SET email = "new@bob.com" WHERE user_id = "bob2""#,
))
.await
.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(3), SchemaVersion(4))
.await
.unwrap();
assert_eq!(
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
r#"SELECT version FROM metadata"#
))
.one(&sql_pool)
.await
.unwrap()
.unwrap(),
sql_migrations::JustSchemaVersion {
version: SchemaVersion(4)
}
);
}
#[tokio::test]
async fn test_migration_to_v5() {
crate::logging::init_for_tests();
let sql_pool = get_in_memory_db().await;
upgrade_to_v1(&sql_pool).await.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(4))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, email, creation_date, uuid)
VALUES ("bob", "bob@bob.com", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#,
))
.await
.unwrap();
sql_pool
.execute(sea_orm::Statement::from_sql_and_values(DbBackend::Sqlite,
r#"INSERT INTO users (user_id, email, display_name, first_name, last_name, avatar, creation_date, uuid)
VALUES ("bob2", "bob2@bob.com", "display bob", "first bob", "last bob", $1, "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#, [JpegPhoto::for_tests().into()]),
)
.await
.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(4), SchemaVersion(5))
.await
.unwrap();
assert_eq!(
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
r#"SELECT version FROM metadata"#
))
.one(&sql_pool)
.await
.unwrap()
.unwrap(),
sql_migrations::JustSchemaVersion {
version: SchemaVersion(5)
}
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
pub struct UserV5 {
user_id: String,
email: String,
display_name: Option<String>,
}
assert_eq!(
UserV5::find_by_statement(raw_statement(
r#"SELECT user_id, email, display_name FROM users ORDER BY user_id ASC"#
))
.all(&sql_pool)
.await
.unwrap(),
vec![
UserV5 {
user_id: "bob".to_owned(),
email: "bob@bob.com".to_owned(),
display_name: None
},
UserV5 {
user_id: "bob2".to_owned(),
email: "bob2@bob.com".to_owned(),
display_name: Some("display bob".to_owned())
},
]
);
sql_pool
.execute(raw_statement(r#"SELECT first_name FROM users"#))
.await
.unwrap_err();
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
pub struct UserAttribute {
user_attribute_user_id: String,
user_attribute_name: String,
user_attribute_value: Serialized,
}
assert_eq!(
UserAttribute::find_by_statement(raw_statement(r#"SELECT * FROM user_attributes ORDER BY user_attribute_user_id, user_attribute_name ASC"#))
.all(&sql_pool)
.await
.unwrap(),
vec![
UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "avatar".to_owned(), user_attribute_value: Serialized::from(&JpegPhoto::for_tests()) },
UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "first_name".to_owned(), user_attribute_value: Serialized::from("first bob") },
UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "last_name".to_owned(), user_attribute_value: Serialized::from("last bob") },
]
);
}
#[tokio::test]
async fn test_migration_to_v6() {
crate::logging::init_for_tests();
let sql_pool = get_in_memory_db().await;
upgrade_to_v1(&sql_pool).await.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(5))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, email, display_name, creation_date, uuid)
VALUES ("bob", "BOb@bob.com", "", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name, creation_date, uuid)
VALUES ("BestGroup", "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#,
))
.await
.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(5), SchemaVersion(6))
.await
.unwrap();
assert_eq!(
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
r#"SELECT version FROM metadata"#
))
.one(&sql_pool)
.await
.unwrap()
.unwrap(),
sql_migrations::JustSchemaVersion {
version: SchemaVersion(6)
}
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortUserDetails {
email: String,
lowercase_email: String,
}
let result = ShortUserDetails::find_by_statement(raw_statement(
r#"SELECT email, lowercase_email FROM users WHERE user_id = "bob""#,
))
.one(&sql_pool)
.await
.unwrap()
.unwrap();
assert_eq!(
result,
ShortUserDetails {
email: "BOb@bob.com".to_owned(),
lowercase_email: "bob@bob.com".to_owned(),
}
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortGroupDetails {
display_name: String,
lowercase_display_name: String,
}
let result = ShortGroupDetails::find_by_statement(raw_statement(
r#"SELECT display_name, lowercase_display_name FROM groups"#,
))
.one(&sql_pool)
.await
.unwrap()
.unwrap();
assert_eq!(
result,
ShortGroupDetails {
display_name: "BestGroup".to_owned(),
lowercase_display_name: "bestgroup".to_owned(),
}
);
}
#[tokio::test]
async fn test_too_high_version() {
let sql_pool = get_in_memory_db().await;
sql_pool
.execute(raw_statement(
r#"CREATE TABLE metadata ( version INTEGER);"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO metadata (version)
VALUES (127)"#,
))
.await
.unwrap();
assert!(init_table(&sql_pool).await.is_err());
}
}
File diff suppressed because it is too large Load Diff