use std::sync::Arc;
use std::time::Instant;
use actix_web::cookie::time::Duration;
use async_trait::async_trait;
use futures_util::{stream, StreamExt};
use reqwest::header::HeaderValue;
use reqwest::Url;
use secrecy::{ExposeSecret, Secret};
use serde::de::DeserializeOwned;
use tokio::sync::Mutex;
use crate::keycloak_api::dtos::UserDto;
use crate::model::dto::{PageParameters, UserSearchParameters};
use super::traits::KeycloakApi;
use super::{errors::KeycloakApiError, traits::Result};
/// Lowest accepted page number; smaller or missing values are clamped up to this (pages are 1-based).
pub const MIN_PAGE: i32 = 1;
/// Smallest accepted page size; smaller values are clamped up to this.
pub const MIN_PER_PAGE: i32 = 1;
/// Page size used when the caller does not request one.
pub const DEFAULT_PER_PAGE: i32 = 10;
/// Client for the Keycloak admin REST API.
///
/// Authenticates via the OAuth2 client-credentials grant and caches the resulting access token.
/// Cloning is cheap and all clones share the same token cache.
#[derive(Clone)]
pub struct Api {
    /// Root of the admin API; request paths are appended to this URL's path.
    base_url: Url,
    /// Endpoint that issues access tokens.
    token_url: Url,
    /// OAuth2 client id sent with every token request.
    client_id: String,
    /// OAuth2 client secret sent with every token request.
    client_secret: Secret<String>,
    /// Cached token shared between clones; `None` until the first token has been fetched.
    auth_data: Arc<Mutex<Option<AuthData>>>,
}
/// An access token together with the instant after which it must no longer be reused.
#[derive(Clone)]
struct AuthData {
    /// Once `Instant::now()` reaches this point, the token is considered expired.
    expires_at: Instant,
    /// Bearer token placed in the `Authorization` header of API requests.
    access_token: Secret<String>,
}
/// The fields of the token endpoint's JSON response that this client reads.
// Field order is kept as-is: serde's derived `Deserialize` also accepts sequence input,
// where order is significant.
#[derive(serde::Deserialize)]
struct TokenResponse {
    /// The issued bearer token.
    pub access_token: Secret<String>,
    /// Token lifetime in seconds (used as seconds when computing the expiry instant).
    pub expires_in: i64,
}
/// Settings required to construct an [`Api`].
pub struct Config {
    /// OAuth2 client id used for the client-credentials grant.
    pub client_id: String,
    /// OAuth2 client secret used for the client-credentials grant.
    pub client_secret: Secret<String>,
    /// Token endpoint URL.
    ///
    /// The admin API base URL is derived from it by replacing its path and dropping its query.
    pub token_url: Url,
}
#[async_trait]
impl KeycloakApi for Api {
    /// Queries the admin `/users` endpoint for users matching `search_params.username`.
    ///
    /// Only enabled users are requested, in their brief representation.
    ///
    /// Pagination is clamped rather than rejected:
    /// - `page` defaults to, and is at least, [`MIN_PAGE`].
    /// - `per_page` defaults to [`DEFAULT_PER_PAGE`] and is at least [`MIN_PER_PAGE`].
    async fn search_users_by_username(
        &self,
        search_params: &UserSearchParameters,
        pagination: &PageParameters,
        client: &reqwest::Client,
    ) -> Result<Vec<UserDto>> {
        let page = pagination.page.map_or(MIN_PAGE, |v| v.max(MIN_PAGE));
        let per_page = pagination
            .per_page
            .map_or(DEFAULT_PER_PAGE, |v| v.max(MIN_PER_PAGE));
        // `page >= 1`, so `page - 1` cannot overflow. The product can, because both values come
        // from the caller unchecked. Saturate instead of panicking (debug) or wrapping to a
        // negative offset (release).
        let first = (page - 1).saturating_mul(per_page);

        let mut url = self.make_url("/users");
        url.query_pairs_mut()
            .append_pair("username", &search_params.username)
            .append_pair("first", &first.to_string())
            .append_pair("max", &per_page.to_string())
            .append_pair("briefRepresentation", "true")
            .append_pair("enabled", "true");
        self.get::<Vec<UserDto>>(client, url).await
    }

    /// Fetches every user in `user_ids` concurrently and returns them in the same order.
    ///
    /// Each lookup runs in its own spawned task, with at most `MAX_CONCURRENT_REQUESTS`
    /// in flight at once. All lookups are awaited before the results are inspected.
    ///
    /// Returns the first error in input order: either a lookup error, or
    /// `KeycloakApiError::Other` if a spawned task could not be joined.
    async fn get_users_by_ids(
        &self,
        client: &reqwest::Client,
        user_ids: Vec<uuid::Uuid>,
    ) -> Result<Vec<UserDto>> {
        const MAX_CONCURRENT_REQUESTS: usize = 10;

        let results = stream::iter(user_ids)
            .map(|id| {
                let client = client.clone();
                let api = self.clone();
                tokio::spawn(async move { api.get_user_by_id(&client, id).await })
            })
            // `buffered` (not `buffer_unordered`) keeps results in input order.
            .buffered(MAX_CONCURRENT_REQUESTS)
            .collect::<Vec<_>>()
            .await;

        results
            .into_iter()
            .map(|joined| match joined {
                Ok(lookup) => lookup,
                Err(join_err) => Err(KeycloakApiError::Other(join_err.to_string())),
            })
            .collect()
    }

    /// Fetches a single user from the admin `/users/{id}` endpoint.
    async fn get_user_by_id(
        &self,
        client: &reqwest::Client,
        user_id: uuid::Uuid,
    ) -> Result<UserDto> {
        let url = self.make_url(&format!("/users/{user_id}"));
        self.get::<UserDto>(client, url).await
    }
}
impl Api {
#[allow(clippy::expect_used)]
#[must_use]
pub fn new(config: Config) -> Self {
let Config {
token_url,
client_id,
client_secret,
} = config;
let mut base_url = to_base_url(token_url.clone());
base_url.set_path("admin/realms/PermaplanT");
Self {
base_url,
token_url,
client_id,
client_secret,
auth_data: Arc::new(Mutex::new(None)),
}
}
async fn get<T: DeserializeOwned>(&self, client: &reqwest::Client, url: Url) -> Result<T> {
let mut request = reqwest::Request::new(reqwest::Method::GET, url);
let token = self.get_or_refresh_access_token(client).await?;
let token_header =
HeaderValue::from_str(&format!("Bearer {}", token.expose_secret().as_str()))?;
request.headers_mut().append("Authorization", token_header);
let response = client.execute(request).await?;
if let Err(err) = response.error_for_status_ref() {
return Err(KeycloakApiError::Reqwest(err.to_string()));
}
let response_text = response.text().await?;
Ok(serde_json::from_str(&response_text)?)
}
async fn get_or_refresh_access_token(
&self,
client: &reqwest::Client,
) -> Result<Secret<String>> {
let mut guard = self.auth_data.lock().await;
match &*guard {
Some(AuthData {
access_token,
expires_at,
}) if *expires_at > Instant::now() => Ok(access_token.clone()),
_ => {
let new_auth_data = self.refresh_access_token(client).await?;
*guard = Some(new_auth_data.clone());
drop(guard);
Ok(new_auth_data.access_token)
}
}
}
async fn refresh_access_token(&self, client: &reqwest::Client) -> Result<AuthData> {
let token_response = client
.post(self.token_url.clone())
.form(&[
("grant_type", "client_credentials"),
("client_id", &self.client_id),
("client_secret", self.client_secret.expose_secret().as_str()),
])
.send()
.await?
.json::<TokenResponse>()
.await?;
let access_token = token_response.access_token.clone();
let expires_at =
Instant::now() + Duration::seconds(token_response.expires_in) - Duration::seconds(5);
Ok(AuthData {
access_token,
expires_at,
})
}
fn make_url(&self, path: &str) -> url::Url {
let mut url = self.base_url.clone();
url.set_path(&format!("{}{}", self.base_url.path(), path));
url
}
}
/// Returns `url` with all path segments removed and the query cleared.
///
/// Only the path and query are modified; everything else in the URL is left untouched.
///
/// # Panics
///
/// Panics with "Cannot set base url" if `url` cannot be a base (it has no mutable path
/// segments).
fn to_base_url(mut url: Url) -> Url {
    match url.path_segments_mut() {
        Ok(mut segments) => {
            segments.clear();
        }
        Err(()) => panic!("Cannot set base url"),
    }
    url.set_query(None);
    url
}