1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
//! Keycloak API
//!
//! This module contains the implementation of the client for the Keycloak admin API.
//! The documentation of the Keycloak admin API can be found here:
//! <https://www.keycloak.org/docs-api/22.0.1/rest-api/index.html#_users/>

use std::sync::Arc;
use std::time::Instant;

use actix_web::cookie::time::Duration;
use async_trait::async_trait;
use futures_util::{stream, StreamExt};
use reqwest::header::HeaderValue;
use reqwest::Url;
use secrecy::{ExposeSecret, Secret};
use serde::de::DeserializeOwned;
use tokio::sync::Mutex;

use crate::keycloak_api::dtos::UserDto;
use crate::model::dto::{PageParameters, UserSearchParameters};

use super::traits::KeycloakApi;
use super::{errors::KeycloakApiError, traits::Result};

/// The default number of rows returned from a paginated request
/// (used when the caller does not specify `per_page`).
pub const DEFAULT_PER_PAGE: i32 = 10;
/// The minimum value for page number in a paginated request.
/// Pages start at 1. Using a lower value would lead to nonsensical queries,
/// so smaller requested values are clamped up to this one.
pub const MIN_PAGE: i32 = 1;
/// The minimum number of rows returned from a paginated query.
/// Smaller requested `per_page` values are clamped up to this one.
pub const MIN_PER_PAGE: i32 = 1;

/// Client for the Keycloak admin REST API.
///
/// Requests are authenticated with an access token obtained through the OAuth2
/// client credentials grant. The token is cached in `auth_data`. Because that
/// field is an `Arc`, all clones of an `Api` share the same cached token.
#[derive(Clone)]
pub struct Api {
    /// Base url for the Keycloak admin REST API.
    base_url: Url,
    /// Cached access token (needs to be thread safe).
    /// Might be expired, in which case it will be refreshed.
    /// `None` until the first token has been fetched.
    auth_data: Arc<Mutex<Option<AuthData>>>,
    /// Url for requesting the access token from the auth server.
    token_url: Url,
    /// The client id for the oauth2 client.
    client_id: String,
    /// The client secret for the oauth2 client.
    client_secret: Secret<String>,
}

/// Helper struct to cache the access token and its expiration time.
#[derive(Clone)]
struct AuthData {
    /// The access token.
    access_token: Secret<String>,
    /// Point in time after which the token is treated as expired.
    /// It is set slightly earlier than the server-reported expiry,
    /// so the token is refreshed before the server rejects it.
    expires_at: Instant,
}

/// Helper struct to deserialize the JSON response of the token endpoint
/// (client credentials grant).
#[derive(serde::Deserialize)]
struct TokenResponse {
    /// The access token.
    pub access_token: Secret<String>,
    /// Lifetime of the token in seconds, counted from when it was issued.
    /// This is a duration, not a timestamp.
    pub expires_in: i64,
}

/// Configuration needed to construct an [`Api`].
pub struct Config {
    /// Token endpoint of the auth server.
    /// Its scheme and host are also used to derive the admin API base URL.
    pub token_url: Url,
    /// Client id used for the client credentials grant.
    pub client_id: String,
    /// Client secret used for the client credentials grant.
    pub client_secret: Secret<String>,
}

#[async_trait]
impl KeycloakApi for Api {
    /// Searches enabled users by username, one page at a time.
    ///
    /// `pagination.page` is 1-based and is clamped to at least [`MIN_PAGE`].
    /// `pagination.per_page` defaults to [`DEFAULT_PER_PAGE`] and is clamped to at
    /// least [`MIN_PER_PAGE`]. Large values saturate instead of overflowing.
    ///
    /// # Errors
    /// If the access token cannot be obtained, the request fails, or the
    /// response cannot be deserialized.
    async fn search_users_by_username(
        &self,
        search_params: &UserSearchParameters,
        pagination: &PageParameters,
        client: &reqwest::Client,
    ) -> Result<Vec<UserDto>> {
        let page = pagination.page.map_or(MIN_PAGE, |v| v.max(MIN_PAGE));
        let per_page = pagination
            .per_page
            .map_or(DEFAULT_PER_PAGE, |v| v.max(MIN_PER_PAGE));

        // Translate the 1-based page into the 0-based row offset sent as `first`.
        // The values are caller-controlled, so saturate rather than overflow
        // (overflow panics in debug builds and yields a negative offset in release).
        let first = (page - 1).saturating_mul(per_page);

        let mut url = self.make_url("/users");
        url.query_pairs_mut()
            .append_pair("username", &search_params.username)
            .append_pair("first", &first.to_string())
            .append_pair("max", &per_page.to_string())
            // optimize the response by only requesting the brief representation
            .append_pair("briefRepresentation", "true")
            // only request enabled users
            .append_pair("enabled", "true");
        self.get::<Vec<UserDto>>(client, url).await
    }

    /// Fetches the users with the given ids.
    ///
    /// The result is in the same order as `user_ids`. Each lookup runs on its own
    /// spawned task, with at most 10 in flight at once. All lookups are driven to
    /// completion before results are inspected.
    ///
    /// # Errors
    /// Returns the first error in id order. This covers a failed lookup, or a
    /// task that could not be joined (mapped to `KeycloakApiError::Other`).
    async fn get_users_by_ids(
        &self,
        client: &reqwest::Client,
        user_ids: Vec<uuid::Uuid>,
    ) -> Result<Vec<UserDto>> {
        /// Maximum number of user lookups in flight at the same time.
        const MAX_CONCURRENT_REQUESTS: usize = 10;

        let future_stream = stream::iter(user_ids)
            .map(|id| {
                let client = client.clone();
                let api = self.clone();
                tokio::spawn(async move { api.get_user_by_id(&client, id).await })
            })
            // buffered (not buffer_unordered), because we want the users in the order of the ids
            .buffered(MAX_CONCURRENT_REQUESTS);

        let users = future_stream
            .map(|res| match res {
                Ok(Ok(user)) => Ok(user),
                Ok(Err(e)) => Err(e),
                Err(e) => Err(KeycloakApiError::Other(e.to_string())),
            })
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<std::result::Result<Vec<_>, _>>()?;

        Ok(users)
    }

    /// Fetches a single user by id.
    ///
    /// # Errors
    /// If the access token cannot be obtained, the request fails (including
    /// non-success statuses), or the response cannot be deserialized.
    async fn get_user_by_id(
        &self,
        client: &reqwest::Client,
        user_id: uuid::Uuid,
    ) -> Result<UserDto> {
        let url = self.make_url(&format!("/users/{user_id}"));
        self.get::<UserDto>(client, url).await
    }
}

impl Api {
    /// Creates a new Keycloak API client.
    ///
    /// The admin API base URL is derived from `config.token_url`. The path and
    /// query are removed, and the path is then set to `admin/realms/PermaplanT`.
    ///
    /// # Panics
    /// If the config does not contain a valid keycloak auth URI, i.e. if
    /// `config.token_url` cannot be a base URL.
    #[must_use]
    pub fn new(config: Config) -> Self {
        let Config {
            token_url,
            client_id,
            client_secret,
        } = config;
        let mut base_url = to_base_url(token_url.clone());
        base_url.set_path("admin/realms/PermaplanT");

        Self {
            base_url,
            token_url,
            client_id,
            client_secret,
            auth_data: Arc::new(Mutex::new(None)),
        }
    }

    /// Sends a GET request to `url`, authenticated with the access token.
    /// The JSON response body is deserialized as `T`.
    ///
    /// # Errors
    /// - if an access token cannot be obtained,
    /// - if the token cannot be encoded as a header value,
    /// - if the request fails or the server answers with a non-success status,
    /// - if the body cannot be deserialized as `T`.
    async fn get<T: DeserializeOwned>(&self, client: &reqwest::Client, url: Url) -> Result<T> {
        let token = self.get_or_refresh_access_token(client).await?;
        let mut token_header =
            HeaderValue::from_str(&format!("Bearer {}", token.expose_secret().as_str()))?;
        // Keep the bearer token out of `Debug` output of the request/headers.
        token_header.set_sensitive(true);

        let mut request = reqwest::Request::new(reqwest::Method::GET, url);
        request
            .headers_mut()
            .insert(reqwest::header::AUTHORIZATION, token_header);

        // There is a `json` method, but then we can not log the response for debugging easily.
        let response = client.execute(request).await?;

        if let Err(err) = response.error_for_status_ref() {
            return Err(KeycloakApiError::Reqwest(err.to_string()));
        }

        let response_text = response.text().await?;
        Ok(serde_json::from_str(&response_text)?)
    }

    /// Returns the cached access token, or fetches a new one if none is
    /// cached or the cached one has expired.
    ///
    /// The mutex stays locked while a new token is fetched. Concurrent callers
    /// therefore wait for that single refresh instead of each starting their own.
    ///
    /// # Errors
    /// If refreshing the token fails. The cache is left unchanged in that case.
    async fn get_or_refresh_access_token(
        &self,
        client: &reqwest::Client,
    ) -> Result<Secret<String>> {
        let mut guard = self.auth_data.lock().await;

        match &*guard {
            Some(AuthData {
                access_token,
                expires_at,
            }) if *expires_at > Instant::now() => Ok(access_token.clone()),
            _ => {
                let new_auth_data = self.refresh_access_token(client).await?;
                let access_token = new_auth_data.access_token.clone();

                *guard = Some(new_auth_data);
                drop(guard);

                Ok(access_token)
            }
        }
    }

    /// Requests a new access token from the token endpoint, using the
    /// client credentials grant.
    ///
    /// # Errors
    /// - if the request fails,
    /// - if the token endpoint answers with a non-success status (e.g. invalid
    ///   client credentials),
    /// - if the response body cannot be deserialized.
    async fn refresh_access_token(&self, client: &reqwest::Client) -> Result<AuthData> {
        let token_response = client
            .post(self.token_url.clone())
            .form(&[
                ("grant_type", "client_credentials"),
                ("client_id", &self.client_id),
                ("client_secret", self.client_secret.expose_secret().as_str()),
            ])
            .send()
            .await?
            // Report auth failures as an HTTP status error rather than an
            // opaque JSON decoding error on the error body.
            .error_for_status()?
            .json::<TokenResponse>()
            .await?;

        // Consider the token expired 5 seconds early, so it is not sent right
        // as the server stops accepting it.
        let expires_at =
            Instant::now() + Duration::seconds(token_response.expires_in) - Duration::seconds(5);

        Ok(AuthData {
            access_token: token_response.access_token,
            expires_at,
        })
    }

    /// Creates a URL by appending `path` to the base URL's path.
    /// `path` is expected to start with `/`.
    fn make_url(&self, path: &str) -> Url {
        let mut url = self.base_url.clone();
        url.set_path(&format!("{}{}", self.base_url.path(), path));
        url
    }
}

/// Reduces `url` to a base URL by clearing every path segment (leaving `/`)
/// and removing the query string. The fragment is left unchanged.
///
/// # Panics
/// If `url` cannot be a base URL (e.g. a `mailto:` or `data:` URL).
fn to_base_url(mut url: Url) -> Url {
    match url.path_segments_mut() {
        Ok(mut segments) => {
            segments.clear();
        }
        Err(()) => panic!("Cannot set base url"),
    }

    url.set_query(None);
    url
}