backend/sse/
broadcaster.rs

1//! SSE broadcaster.
2//!
3//! This module contains the Server-Sent Events broadcaster, which is responsible for keeping track of connected clients and broadcasting messages to them.
4//! For broadcasting, the broadcaster takes a `map_id` and an `Action` and broadcasts the action to all clients connected to that map.
5
6#![allow(clippy::significant_drop_tightening)]
7
8use actix_web_lab::sse::{self, ChannelStream, Sse};
9use futures::{future::ready, stream, StreamExt};
10use std::{collections::HashMap, sync::Arc, time::Duration};
11use tokio::{sync::Mutex, time::interval};
12
13use crate::model::dto::actions::Action;
14
15/// Map that clients are connected to.
16#[derive(Debug, Clone)]
17struct ConnectedMap {
18    /// Id of the map that the clients are connected to.
19    map_id: i32,
20    /// List of clients connected to the map.
21    clients: Vec<sse::Sender>,
22}
23
24#[derive(Debug, Clone, Default)]
25/// SSE broadcaster.
26///
27/// Inner `HashMap`:
28/// * Map of `map_id` to a list of connected clients.
29/// * The `map_id` is the id of the map that the client connected to.
30/// * The connected map contains the `map_id` and a list of clients connected to that map.
31pub struct Broadcaster(Arc<Mutex<HashMap<i32, ConnectedMap>>>);
32
33impl Broadcaster {
34    /// Constructs new broadcaster and spawns ping loop.
35    #[must_use]
36    pub fn new() -> Self {
37        let broadcaster = Self::default();
38        Self::spawn_ping(broadcaster.clone());
39        broadcaster
40    }
41
42    /// Pings clients every 10 minutes to see if they are alive and remove them from the broadcast list if not.
43    fn spawn_ping(self) {
44        actix_web::rt::spawn(async move {
45            let mut interval = interval(Duration::from_secs(600));
46            loop {
47                interval.tick().await;
48                self.clone().remove_stale_clients().await;
49            }
50        });
51    }
52
53    /// Removes all non-responsive clients from broadcast list.
54    /// TODO: this is a naive implementation, we should probably use a better data structure for this.
55    ///       Things to consider:
56    ///        - how can we do this without having to iterate over all clients?
57    async fn remove_stale_clients(&self) {
58        let mut guard = self.0.lock().await;
59
60        let mut ok_maps = HashMap::with_capacity(guard.capacity());
61
62        stream::iter(guard.values())
63            .map(|map| async move {
64                (
65                    map,
66                    stream::iter(&map.clients)
67                        .filter(|client| async {
68                            client
69                                .send(sse::Event::Comment("ping".into()))
70                                .await
71                                .is_ok()
72                        })
73                        .map(|client| ready(client.clone()))
74                        .buffer_unordered(15)
75                        .collect::<Vec<_>>()
76                        .await,
77                )
78            })
79            .buffer_unordered(100)
80            .filter(|(_, ok_clients)| ready(!ok_clients.is_empty()))
81            .for_each(|(map, ok_clients)| {
82                ok_maps.insert(
83                    map.map_id,
84                    ConnectedMap {
85                        map_id: map.map_id,
86                        clients: ok_clients,
87                    },
88                );
89                ready(())
90            })
91            .await;
92
93        *guard = ok_maps;
94    }
95
96    /// Registers client with broadcaster, returning an SSE response body.
97    ///
98    /// # Errors
99    /// * If `sender.send()` fails for the new client.
100    pub async fn new_client(
101        &self,
102        map_id: i32,
103    ) -> Result<Sse<ChannelStream>, Box<dyn std::error::Error>> {
104        let (sender, channel_stream) = sse::channel(100);
105        let mut guard = self.0.lock().await;
106
107        let map = guard.entry(map_id).or_insert_with(|| ConnectedMap {
108            map_id,
109            clients: Vec::new(),
110        });
111
112        sender.send(sse::Data::new("connected")).await?;
113
114        map.clients.push(sender);
115
116        Ok(channel_stream)
117    }
118
119    /// Broadcasts `msg` to all clients on the same map.
120    pub async fn broadcast(&self, map_id: i32, action: Action) {
121        let action_id = action.action_id.to_string();
122
123        match sse::Data::new_json(action) {
124            Ok(mut serialized_action) => {
125                let guard = self.0.lock().await;
126
127                serialized_action.set_id(action_id);
128
129                if let Some(map) = guard.get(&map_id) {
130                    // try to send to all clients, ignoring failures
131                    // disconnected clients will get swept up by `remove_stale_clients`
132                    let _ = stream::iter(&map.clients)
133                        .map(|client| client.send(serialized_action.clone()))
134                        .buffer_unordered(15)
135                        .collect::<Vec<_>>()
136                        .await;
137                }
138            }
139            Err(err) => {
140                // log the error and continue
141                // serialization errors are also highly unlikely to happen
142                log::error!("{err}");
143            }
144        }
145    }
146
147    /// Broadcasts `msg` to all clients on all maps.
148    pub async fn broadcast_all_maps(&self, action: Action) {
149        let action_id = action.action_id.to_string();
150
151        match sse::Data::new_json(action) {
152            Ok(mut serialized_action) => {
153                let guard = self.0.lock().await;
154
155                serialized_action.set_id(action_id);
156
157                let values = guard.values();
158                for map in values {
159                    // try to send to all clients, ignoring failures
160                    // disconnected clients will get swept up by `remove_stale_clients`
161                    let _ = stream::iter(&map.clients)
162                        .map(|client| client.send(serialized_action.clone()))
163                        .buffer_unordered(15)
164                        .collect::<Vec<_>>()
165                        .await;
166                }
167            }
168            Err(err) => {
169                // log the error and continue
170                // serialization errors are also highly unlikely to happen
171                log::error!("{err}");
172            }
173        }
174    }
175}