diesel_async/
transaction_manager.rs1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4    InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use scoped_futures::ScopedBoxFuture;
9use std::borrow::Cow;
10use std::num::NonZeroU32;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14use crate::AsyncConnection;
15#[async_trait::async_trait]
22pub trait TransactionManager<Conn: AsyncConnection>: Send {
23    type TransactionStateData;
26
27    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()>;
33
34    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()>;
40
41    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()>;
47
48    #[doc(hidden)]
54    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
55
56    async fn transaction<'a, F, R, E>(conn: &mut Conn, callback: F) -> Result<R, E>
61    where
62        F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
63        E: From<Error> + Send,
64        R: Send,
65    {
66        Self::begin_transaction(conn).await?;
67        match callback(&mut *conn).await {
68            Ok(value) => {
69                Self::commit_transaction(conn).await?;
70                Ok(value)
71            }
72            Err(user_error) => match Self::rollback_transaction(conn).await {
73                Ok(()) => Err(user_error),
74                Err(Error::BrokenTransactionManager) => {
75                    Err(user_error)
78                }
79                Err(rollback_error) => Err(rollback_error.into()),
80            },
81        }
82    }
83
84    #[doc(hidden)]
92    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
93        check_broken_transaction_state(conn)
94    }
95}
96
97fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
98where
99    Conn: AsyncConnection,
100{
101    match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
102        Ok(ValidTransactionManagerStatus {
105            in_transaction: None,
106            ..
107        }) => false,
108        Err(_) => true,
111        Ok(ValidTransactionManagerStatus {
115            in_transaction: Some(s),
116            ..
117        }) => !s.test_transaction,
118    }
119}
120
121#[derive(Default, Debug)]
124pub struct AnsiTransactionManager {
125    pub(crate) status: TransactionManagerStatus,
126    pub(crate) is_broken: Arc<AtomicBool>,
141}
142
143impl AnsiTransactionManager {
144    fn get_transaction_state<Conn>(
145        conn: &mut Conn,
146    ) -> QueryResult<&mut ValidTransactionManagerStatus>
147    where
148        Conn: AsyncConnection<TransactionManager = Self>,
149    {
150        conn.transaction_state().status.transaction_state()
151    }
152
153    pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
159    where
160        Conn: AsyncConnection<TransactionManager = Self>,
161    {
162        let is_broken = conn.transaction_state().is_broken.clone();
163        let state = Self::get_transaction_state(conn)?;
164        match state.transaction_depth() {
165            None => {
166                Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
167                Self::get_transaction_state(conn)?
168                    .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
169                Ok(())
170            }
171            Some(_depth) => Err(Error::AlreadyInTransaction),
172        }
173    }
174
175    async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
182    where
183        F: std::future::Future,
184    {
185        let was_broken = is_broken.swap(true, Ordering::Relaxed);
186        debug_assert!(
187            !was_broken,
188            "Tried to execute a transaction SQL on transaction manager that was previously cancled"
189        );
190        let res = f.await;
191        is_broken.store(false, Ordering::Relaxed);
192        res
193    }
194}
195
196#[async_trait::async_trait]
197impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
198where
199    Conn: AsyncConnection<TransactionManager = Self>,
200{
201    type TransactionStateData = Self;
202
203    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
204        let transaction_state = Self::get_transaction_state(conn)?;
205        let start_transaction_sql = match transaction_state.transaction_depth() {
206            None => Cow::from("BEGIN"),
207            Some(transaction_depth) => {
208                Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
209            }
210        };
211        let depth = transaction_state
212            .transaction_depth()
213            .and_then(|d| d.checked_add(1))
214            .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
215        conn.instrumentation()
216            .on_connection_event(InstrumentationEvent::begin_transaction(depth));
217        Self::critical_transaction_block(
218            &conn.transaction_state().is_broken.clone(),
219            conn.batch_execute(&start_transaction_sql),
220        )
221        .await?;
222        Self::get_transaction_state(conn)?
223            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
224
225        Ok(())
226    }
227
228    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
229        let transaction_state = Self::get_transaction_state(conn)?;
230
231        let (
232            (rollback_sql, rolling_back_top_level),
233            requires_rollback_maybe_up_to_top_level_before_execute,
234        ) = match transaction_state.in_transaction {
235            Some(ref in_transaction) => (
236                match in_transaction.transaction_depth.get() {
237                    1 => (Cow::Borrowed("ROLLBACK"), true),
238                    depth_gt1 => (
239                        Cow::Owned(format!(
240                            "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
241                            depth_gt1 - 1
242                        )),
243                        false,
244                    ),
245                },
246                in_transaction.requires_rollback_maybe_up_to_top_level,
247            ),
248            None => return Err(Error::NotInTransaction),
249        };
250
251        let depth = transaction_state
252            .transaction_depth()
253            .expect("We know that we are in a transaction here");
254        conn.instrumentation()
255            .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
256
257        let is_broken = conn.transaction_state().is_broken.clone();
258
259        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
260        {
261            Ok(()) => {
262                match Self::get_transaction_state(conn)?
263                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
264                {
265                    Ok(()) => {}
266                    Err(Error::NotInTransaction) if rolling_back_top_level => {
267                        }
270                    Err(e) => return Err(e),
271                }
272                Ok(())
273            }
274            Err(rollback_error) => {
275                let tm_status = Self::transaction_manager_status_mut(conn);
276                match tm_status {
277                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
278                        in_transaction:
279                            Some(InTransactionStatus {
280                                transaction_depth,
281                                requires_rollback_maybe_up_to_top_level,
282                                ..
283                            }),
284                        ..
285                    }) if transaction_depth.get() > 1 => {
286                        *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
294                            .expect("Depth was checked to be > 1");
295                        *requires_rollback_maybe_up_to_top_level = true;
296                        if requires_rollback_maybe_up_to_top_level_before_execute {
297                            return Ok(());
300                        }
301                    }
302                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
303                        in_transaction: None,
304                        ..
305                    }) => {
306                        }
311                    _ => tm_status.set_in_error(),
312                }
313                Err(rollback_error)
314            }
315        }
316    }
317
318    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
324        let transaction_state = Self::get_transaction_state(conn)?;
325        let transaction_depth = transaction_state.transaction_depth();
326        let (commit_sql, committing_top_level) = match transaction_depth {
327            None => return Err(Error::NotInTransaction),
328            Some(transaction_depth) if transaction_depth.get() == 1 => {
329                (Cow::Borrowed("COMMIT"), true)
330            }
331            Some(transaction_depth) => (
332                Cow::Owned(format!(
333                    "RELEASE SAVEPOINT diesel_savepoint_{}",
334                    transaction_depth.get() - 1
335                )),
336                false,
337            ),
338        };
339        let depth = transaction_state
340            .transaction_depth()
341            .expect("We know that we are in a transaction here");
342        conn.instrumentation()
343            .on_connection_event(InstrumentationEvent::commit_transaction(depth));
344
345        let is_broken = conn.transaction_state().is_broken.clone();
346
347        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
348            Ok(()) => {
349                match Self::get_transaction_state(conn)?
350                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
351                {
352                    Ok(()) => {}
353                    Err(Error::NotInTransaction) if committing_top_level => {
354                        }
357                    Err(e) => return Err(e),
358                }
359                Ok(())
360            }
361            Err(commit_error) => {
362                if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
363                    in_transaction:
364                        Some(InTransactionStatus {
365                            requires_rollback_maybe_up_to_top_level: true,
366                            ..
367                        }),
368                    ..
369                }) = conn.transaction_state().status
370                {
371                    match Self::critical_transaction_block(
372                        &is_broken,
373                        Self::rollback_transaction(conn),
374                    )
375                    .await
376                    {
377                        Ok(()) => {}
378                        Err(rollback_error) => {
379                            conn.transaction_state().status.set_in_error();
380                            return Err(Error::RollbackErrorOnCommit {
381                                rollback_error: Box::new(rollback_error),
382                                commit_error: Box::new(commit_error),
383                            });
384                        }
385                    }
386                }
387                Err(commit_error)
388            }
389        }
390    }
391
392    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
393        &mut conn.transaction_state().status
394    }
395
396    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
397        conn.transaction_state().is_broken.load(Ordering::Relaxed)
398            || check_broken_transaction_state(conn)
399    }
400}