ldap: normalize base DN in LdapInfo, reduce memory usage

By making it a &'static, we can have a single allocation for all the threads/async contexts.

This also normalizes the whitespace from the user input; a trailing \n can cause weird issues with clients
This commit is contained in:
Valentin Tolmer
2025-09-17 00:47:53 +02:00
committed by nitnelave
parent f7fe0c6ea0
commit 8a803bfb11
5 changed files with 63 additions and 59 deletions
+27
View File
@@ -300,6 +300,23 @@ pub struct LdapInfo {
pub ignored_group_attributes: Vec<AttributeName>, pub ignored_group_attributes: Vec<AttributeName>,
} }
impl LdapInfo {
pub fn new(
base_dn: &str,
ignored_user_attributes: Vec<AttributeName>,
ignored_group_attributes: Vec<AttributeName>,
) -> LdapResult<Self> {
let base_dn = parse_distinguished_name(&base_dn.to_ascii_lowercase())?;
let base_dn_str = join(base_dn.iter().map(|(k, v)| format!("{k}={v}")), ",");
Ok(Self {
base_dn,
base_dn_str,
ignored_user_attributes,
ignored_group_attributes,
})
}
}
pub fn get_custom_attribute( pub fn get_custom_attribute(
attributes: &[Attribute], attributes: &[Attribute],
attribute_name: &AttributeName, attribute_name: &AttributeName,
@@ -521,4 +538,14 @@ mod tests {
parsed_dn parsed_dn
); );
} }
#[test]
fn test_whitespace_in_ldap_info() {
assert_eq!(
LdapInfo::new(" ou=people, dc =example, dc=com \n", vec![], vec![])
.unwrap()
.base_dn_str,
"ou=people,dc=example,dc=com"
);
}
} }
+14 -24
View File
@@ -2,7 +2,7 @@ use crate::{
compare, compare,
core::{ core::{
error::{LdapError, LdapResult}, error::{LdapError, LdapResult},
utils::{LdapInfo, parse_distinguished_name}, utils::LdapInfo,
}, },
create, delete, modify, create, delete, modify,
password::{self, do_password_modification}, password::{self, do_password_modification},
@@ -18,7 +18,7 @@ use ldap3_proto::proto::{
}; };
use lldap_access_control::AccessControlledBackendHandler; use lldap_access_control::AccessControlledBackendHandler;
use lldap_auth::access_control::ValidationResults; use lldap_auth::access_control::ValidationResults;
use lldap_domain::{public_schema::PublicSchema, types::AttributeName}; use lldap_domain::public_schema::PublicSchema;
use lldap_domain_handlers::handler::{BackendHandler, LoginHandler, ReadSchemaBackendHandler}; use lldap_domain_handlers::handler::{BackendHandler, LoginHandler, ReadSchemaBackendHandler};
use lldap_opaque_handler::OpaqueHandler; use lldap_opaque_handler::OpaqueHandler;
use tracing::{debug, instrument}; use tracing::{debug, instrument};
@@ -59,7 +59,7 @@ pub(crate) fn make_modify_response(code: LdapResultCode, message: String) -> Lda
pub struct LdapHandler<Backend> { pub struct LdapHandler<Backend> {
user_info: Option<ValidationResults>, user_info: Option<ValidationResults>,
backend_handler: AccessControlledBackendHandler<Backend>, backend_handler: AccessControlledBackendHandler<Backend>,
ldap_info: LdapInfo, ldap_info: &'static LdapInfo,
session_uuid: uuid::Uuid, session_uuid: uuid::Uuid,
} }
@@ -89,23 +89,13 @@ enum Credentials<'s> {
impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend> { impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend> {
pub fn new( pub fn new(
backend_handler: AccessControlledBackendHandler<Backend>, backend_handler: AccessControlledBackendHandler<Backend>,
mut ldap_base_dn: String, ldap_info: &'static LdapInfo,
ignored_user_attributes: Vec<AttributeName>,
ignored_group_attributes: Vec<AttributeName>,
session_uuid: uuid::Uuid, session_uuid: uuid::Uuid,
) -> Self { ) -> Self {
ldap_base_dn.make_ascii_lowercase();
Self { Self {
user_info: None, user_info: None,
backend_handler, backend_handler,
ldap_info: LdapInfo { ldap_info,
base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| {
panic!("Invalid value for ldap_base_dn in configuration: {ldap_base_dn}")
}),
base_dn_str: ldap_base_dn,
ignored_user_attributes,
ignored_group_attributes,
},
session_uuid, session_uuid,
} }
} }
@@ -114,9 +104,9 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
pub fn new_for_tests(backend_handler: Backend, ldap_base_dn: &str) -> Self { pub fn new_for_tests(backend_handler: Backend, ldap_base_dn: &str) -> Self {
Self::new( Self::new(
AccessControlledBackendHandler::new(backend_handler), AccessControlledBackendHandler::new(backend_handler),
ldap_base_dn.to_string(), Box::leak(Box::new(
vec![], LdapInfo::new(ldap_base_dn, Vec::new(), Vec::new()).unwrap(),
vec![], )),
uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(), uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(),
) )
} }
@@ -171,13 +161,13 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
let backend_handler = self let backend_handler = self
.backend_handler .backend_handler
.get_user_restricted_lister_handler(user_info); .get_user_restricted_lister_handler(user_info);
search::do_search(&backend_handler, &self.ldap_info, request).await search::do_search(&backend_handler, self.ldap_info, request).await
} }
#[instrument(skip_all, level = "debug", fields(dn = %request.dn))] #[instrument(skip_all, level = "debug", fields(dn = %request.dn))]
pub async fn do_bind(&mut self, request: &LdapBindRequest) -> Vec<LdapOp> { pub async fn do_bind(&mut self, request: &LdapBindRequest) -> Vec<LdapOp> {
let (code, message) = let (code, message) =
match password::do_bind(&self.ldap_info, request, self.get_login_handler()).await { match password::do_bind(self.ldap_info, request, self.get_login_handler()).await {
Ok(user_id) => { Ok(user_id) => {
self.user_info = self self.user_info = self
.backend_handler .backend_handler
@@ -211,7 +201,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
}; };
do_password_modification( do_password_modification(
credentials, credentials,
&self.ldap_info, self.ldap_info,
&self.backend_handler, &self.backend_handler,
self.get_opaque_handler(), self.get_opaque_handler(),
&password_request, &password_request,
@@ -257,7 +247,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
self.backend_handler self.backend_handler
.get_readable_handler(credentials, &user_id) .get_readable_handler(credentials, &user_id)
}, },
&self.ldap_info, self.ldap_info,
credentials, credentials,
request, request,
) )
@@ -275,7 +265,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
code: LdapResultCode::InsufficentAccessRights, code: LdapResultCode::InsufficentAccessRights,
message: "Unauthorized write".to_string(), message: "Unauthorized write".to_string(),
})?; })?;
create::create_user_or_group(backend_handler, &self.ldap_info, request).await create::create_user_or_group(backend_handler, self.ldap_info, request).await
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
@@ -288,7 +278,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
code: LdapResultCode::InsufficentAccessRights, code: LdapResultCode::InsufficentAccessRights,
message: "Unauthorized write".to_string(), message: "Unauthorized write".to_string(),
})?; })?;
delete::delete_user_or_group(backend_handler, &self.ldap_info, request).await delete::delete_user_or_group(backend_handler, self.ldap_info, request).await
} }
#[instrument(skip_all, level = "debug")] #[instrument(skip_all, level = "debug")]
+1 -1
View File
@@ -7,7 +7,7 @@ pub(crate) mod modify;
pub(crate) mod password; pub(crate) mod password;
pub(crate) mod search; pub(crate) mod search;
pub use core::utils::{UserFieldType, map_group_field, map_user_field}; pub use core::utils::{LdapInfo, UserFieldType, map_group_field, map_user_field};
pub use handler::LdapHandler; pub use handler::LdapHandler;
pub use core::group::get_default_group_object_classes; pub use core::group::get_default_group_object_classes;
+1 -2
View File
@@ -17,7 +17,7 @@ use lldap_domain::{
public_schema::PublicSchema, public_schema::PublicSchema,
types::{Group, UserAndGroups}, types::{Group, UserAndGroups},
}; };
use tracing::{debug, instrument, warn}; use tracing::{debug, warn};
#[derive(Debug)] #[derive(Debug)]
enum SearchScope { enum SearchScope {
@@ -396,7 +396,6 @@ async fn do_search_internal(
}) })
} }
#[instrument(skip_all, level = "debug")]
pub async fn do_search( pub async fn do_search(
backend_handler: &impl UserAndGroupListerBackendHandler, backend_handler: &impl UserAndGroupListerBackendHandler,
ldap_info: &LdapInfo, ldap_info: &LdapInfo,
+20 -32
View File
@@ -5,9 +5,8 @@ use actix_service::{ServiceFactoryExt, fn_service};
use anyhow::{Context, Result, anyhow}; use anyhow::{Context, Result, anyhow};
use ldap3_proto::{LdapCodec, control::LdapControl, proto::LdapMsg, proto::LdapOp}; use ldap3_proto::{LdapCodec, control::LdapControl, proto::LdapMsg, proto::LdapOp};
use lldap_access_control::AccessControlledBackendHandler; use lldap_access_control::AccessControlledBackendHandler;
use lldap_domain::types::AttributeName;
use lldap_domain_handlers::handler::{BackendHandler, LoginHandler}; use lldap_domain_handlers::handler::{BackendHandler, LoginHandler};
use lldap_ldap::LdapHandler; use lldap_ldap::{LdapHandler, LdapInfo};
use lldap_opaque_handler::OpaqueHandler; use lldap_opaque_handler::OpaqueHandler;
use rustls::PrivateKey; use rustls::PrivateKey;
use tokio_rustls::TlsAcceptor as RustlsTlsAcceptor; use tokio_rustls::TlsAcceptor as RustlsTlsAcceptor;
@@ -71,9 +70,7 @@ where
async fn handle_ldap_stream<Stream, Backend>( async fn handle_ldap_stream<Stream, Backend>(
stream: Stream, stream: Stream,
backend_handler: Backend, backend_handler: Backend,
ldap_base_dn: String, ldap_info: &'static LdapInfo,
ignored_user_attributes: Vec<AttributeName>,
ignored_group_attributes: Vec<AttributeName>,
) -> Result<Stream> ) -> Result<Stream>
where where
Backend: BackendHandler + LoginHandler + OpaqueHandler + 'static, Backend: BackendHandler + LoginHandler + OpaqueHandler + 'static,
@@ -88,9 +85,7 @@ where
let session_uuid = Uuid::new_v4(); let session_uuid = Uuid::new_v4();
let mut session = LdapHandler::new( let mut session = LdapHandler::new(
AccessControlledBackendHandler::new(backend_handler), AccessControlledBackendHandler::new(backend_handler),
ldap_base_dn, ldap_info,
ignored_user_attributes,
ignored_group_attributes,
session_uuid, session_uuid,
); );
@@ -170,9 +165,19 @@ where
{ {
let context = ( let context = (
backend_handler, backend_handler,
config.ldap_base_dn.clone(), Box::leak(Box::new(
config.ignored_user_attributes.clone(), LdapInfo::new(
config.ignored_group_attributes.clone(), &config.ldap_base_dn,
config.ignored_user_attributes.clone(),
config.ignored_group_attributes.clone(),
)
.with_context(|| {
format!(
"Invalid value for ldap_base_dn in configuration: {}",
&config.ldap_base_dn
)
})?,
)) as &'static LdapInfo,
); );
let context_for_tls = context.clone(); let context_for_tls = context.clone();
@@ -182,15 +187,8 @@ where
fn_service(move |stream: TcpStream| { fn_service(move |stream: TcpStream| {
let context = context.clone(); let context = context.clone();
async move { async move {
let (handler, base_dn, ignored_user_attributes, ignored_group_attributes) = context; let (handler, ldap_info) = context;
handle_ldap_stream( handle_ldap_stream(stream, handler, ldap_info).await
stream,
handler,
base_dn,
ignored_user_attributes,
ignored_group_attributes,
)
.await
} }
}) })
.map_err(|err: anyhow::Error| error!("[LDAP] Service Error: {:#}", err)) .map_err(|err: anyhow::Error| error!("[LDAP] Service Error: {:#}", err))
@@ -211,19 +209,9 @@ where
fn_service(move |stream: TcpStream| { fn_service(move |stream: TcpStream| {
let tls_context = tls_context.clone(); let tls_context = tls_context.clone();
async move { async move {
let ( let ((handler, ldap_info), tls_acceptor) = tls_context;
(handler, base_dn, ignored_user_attributes, ignored_group_attributes),
tls_acceptor,
) = tls_context;
let tls_stream = tls_acceptor.accept(stream).await?; let tls_stream = tls_acceptor.accept(stream).await?;
handle_ldap_stream( handle_ldap_stream(tls_stream, handler, ldap_info).await
tls_stream,
handler,
base_dn,
ignored_user_attributes,
ignored_group_attributes,
)
.await
} }
}) })
.map_err(|err: anyhow::Error| error!("[LDAPS] Service Error: {:#}", err)) .map_err(|err: anyhow::Error| error!("[LDAPS] Service Error: {:#}", err))