backend/keycloak_api/
api.rs

//! Keycloak API
//!
//! This module contains the implementation of the client for the Keycloak admin API.
//! Find the documentation of the Keycloak admin API here:
//! <https://www.keycloak.org/docs-api/22.0.1/rest-api/index.html#_users/>
6
7use std::sync::Arc;
8use std::time::Instant;
9
10use actix_web::cookie::time::Duration;
11use async_trait::async_trait;
12use futures_util::{stream, StreamExt};
13use reqwest::header::HeaderValue;
14use reqwest::Url;
15use secrecy::{ExposeSecret, Secret};
16use serde::de::DeserializeOwned;
17use tokio::sync::Mutex;
18
19use crate::keycloak_api::dtos::UserDto;
20use crate::model::dto::{PageParameters, UserSearchParameters};
21
22use super::traits::KeycloakApi;
23use super::{errors::KeycloakApiError, traits::Result};
24
/// Page size used for a paginated request when the caller does not specify one.
pub const DEFAULT_PER_PAGE: i32 = 10;
/// Smallest accepted page number in a paginated request.
/// Page numbering is 1-based, so anything lower would produce a nonsensical query.
pub const MIN_PAGE: i32 = 1;
/// Smallest accepted page size (rows per page) in a paginated request.
pub const MIN_PER_PAGE: i32 = 1;
32
33/// The keycloak admin API.
34#[derive(Clone)]
35pub struct Api {
36    /// Base url for the Keycloak admin REST API.
37    base_url: Url,
38    /// Cached access token (needs to be thread safe).
39    /// Might be expired, in which case it will be refreshed.
40    auth_data: Arc<Mutex<Option<AuthData>>>,
41    /// Url for requesting the access token from the auth server.
42    token_url: Url,
43    /// The client id for the oauth2 client.
44    client_id: String,
45    /// The client secret for the oauth2 client.
46    client_secret: Secret<String>,
47}
48
49/// Helper struct to cache the access token and its expiration time.
50#[derive(Clone)]
51struct AuthData {
52    /// The access token.
53    access_token: Secret<String>,
54    /// Timestamp the token expires.
55    expires_at: Instant,
56}
57
58/// Helper struct to deserialize the token response.
59#[derive(serde::Deserialize)]
60struct TokenResponse {
61    /// The access token.
62    pub access_token: Secret<String>,
63    /// Timestamp the token expires.
64    pub expires_in: i64,
65}
66
67pub struct Config {
68    pub token_url: Url,
69    pub client_id: String,
70    pub client_secret: Secret<String>,
71}
72
73#[async_trait]
74impl KeycloakApi for Api {
75    async fn search_users_by_username(
76        &self,
77        search_params: &UserSearchParameters,
78        pagination: &PageParameters,
79        client: &reqwest::Client,
80    ) -> Result<Vec<UserDto>> {
81        let page = pagination
82            .page
83            .map_or(MIN_PAGE, |v| if v < MIN_PAGE { MIN_PAGE } else { v });
84        let per_page = pagination.per_page.map_or(DEFAULT_PER_PAGE, |v| {
85            if v < MIN_PER_PAGE {
86                MIN_PER_PAGE
87            } else {
88                v
89            }
90        });
91
92        let first = (page - 1) * per_page;
93
94        let mut url = self.make_url("/users");
95        url.query_pairs_mut()
96            .append_pair("username", &search_params.username)
97            .append_pair("first", &format!("{first}"))
98            .append_pair("max", &format!("{per_page}"))
99            // optimize the response by only requesting the brief representation
100            .append_pair("briefRepresentation", "true")
101            // only request enabled users
102            .append_pair("enabled", "true");
103        self.get::<Vec<UserDto>>(client, url).await
104    }
105
106    async fn get_users_by_ids(
107        &self,
108        client: &reqwest::Client,
109        user_ids: Vec<uuid::Uuid>,
110    ) -> Result<Vec<UserDto>> {
111        let future_stream = stream::iter(user_ids)
112            .map(|id| {
113                let client = client.clone();
114                let api = self.clone();
115                tokio::spawn(async move { api.get_user_by_id(&client, id).await })
116            })
117            .buffered(10); // buffered, because we want the users in the order of the ids
118
119        let users = future_stream
120            .map(|res| match res {
121                Ok(Ok(user)) => Ok(user),
122                Ok(Err(e)) => Err(e),
123                Err(e) => Err(KeycloakApiError::Other(e.to_string())),
124            })
125            .collect::<Vec<_>>()
126            .await
127            .into_iter()
128            .collect::<std::result::Result<Vec<_>, _>>()?;
129
130        Ok(users)
131    }
132
133    async fn get_user_by_id(
134        &self,
135        client: &reqwest::Client,
136        user_id: uuid::Uuid,
137    ) -> Result<UserDto> {
138        let url = self.make_url(&format!("/users/{user_id}"));
139        self.get::<UserDto>(client, url).await
140    }
141}
142
143impl Api {
144    /// Creates a new Keycloak API.
145    ///
146    /// # Panics
147    /// If the config does not contain a valid keycloak auth URI.
148    #[allow(clippy::expect_used)]
149    #[must_use]
150    pub fn new(config: Config) -> Self {
151        let Config {
152            token_url,
153            client_id,
154            client_secret,
155        } = config;
156        let mut base_url = to_base_url(token_url.clone());
157        base_url.set_path("admin/realms/PermaplanT");
158
159        Self {
160            base_url,
161            token_url,
162            client_id,
163            client_secret,
164            auth_data: Arc::new(Mutex::new(None)),
165        }
166    }
167
168    /// Executes a get request authenticated with the access token.
169    async fn get<T: DeserializeOwned>(&self, client: &reqwest::Client, url: Url) -> Result<T> {
170        let mut request = reqwest::Request::new(reqwest::Method::GET, url);
171        let token = self.get_or_refresh_access_token(client).await?;
172        let token_header =
173            HeaderValue::from_str(&format!("Bearer {}", token.expose_secret().as_str()))?;
174        request.headers_mut().append("Authorization", token_header);
175
176        // There is a `json` method, but then we can not log the response for debugging easily.
177        let response = client.execute(request).await?;
178
179        if let Err(err) = response.error_for_status_ref() {
180            return Err(KeycloakApiError::Reqwest(err.to_string()));
181        }
182
183        let response_text = response.text().await?;
184        Ok(serde_json::from_str(&response_text)?)
185    }
186
187    /// Gets the access token or refreshes it if it is expired.
188    async fn get_or_refresh_access_token(
189        &self,
190        client: &reqwest::Client,
191    ) -> Result<Secret<String>> {
192        let mut guard = self.auth_data.lock().await;
193
194        match &*guard {
195            Some(AuthData {
196                access_token,
197                expires_at,
198            }) if *expires_at > Instant::now() => Ok(access_token.clone()),
199            _ => {
200                let new_auth_data = self.refresh_access_token(client).await?;
201
202                *guard = Some(new_auth_data.clone());
203                drop(guard);
204
205                Ok(new_auth_data.access_token)
206            }
207        }
208    }
209
210    /// Refresh the access token.
211    async fn refresh_access_token(&self, client: &reqwest::Client) -> Result<AuthData> {
212        let token_response = client
213            .post(self.token_url.clone())
214            .form(&[
215                ("grant_type", "client_credentials"),
216                ("client_id", &self.client_id),
217                ("client_secret", self.client_secret.expose_secret().as_str()),
218            ])
219            .send()
220            .await?
221            .json::<TokenResponse>()
222            .await?;
223
224        let access_token = token_response.access_token.clone();
225        let expires_at =
226            Instant::now() + Duration::seconds(token_response.expires_in) - Duration::seconds(5);
227
228        Ok(AuthData {
229            access_token,
230            expires_at,
231        })
232    }
233
234    /// Creates a URL from the base URL and the given path.
235    fn make_url(&self, path: &str) -> url::Url {
236        let mut url = self.base_url.clone();
237        url.set_path(&format!("{}{}", self.base_url.path(), path));
238        url
239    }
240}
241
242/// Helper function to create a base URL.
243///
244/// # Panics
245/// If the URL cannot be a base URL.
246fn to_base_url(mut url: Url) -> Url {
247    url.path_segments_mut().map_or_else(
248        |()| panic!("Cannot set base url"),
249        |mut segments| {
250            segments.clear();
251        },
252    );
253
254    url.set_query(None);
255    url
256}