diesel_async/
transaction_manager.rs

1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4    InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use scoped_futures::ScopedBoxFuture;
9use std::borrow::Cow;
10use std::num::NonZeroU32;
11
12use crate::AsyncConnection;
13// TODO: refactor this to share more code with diesel
14
15/// Manages the internal transaction state for a connection.
16///
17/// You will not need to interact with this trait, unless you are writing an
18/// implementation of [`AsyncConnection`].
19#[async_trait::async_trait]
20pub trait TransactionManager<Conn: AsyncConnection>: Send {
21    /// Data stored as part of the connection implementation
22    /// to track the current transaction state of a connection
23    type TransactionStateData;
24
25    /// Begin a new transaction or savepoint
26    ///
27    /// If the transaction depth is greater than 0,
28    /// this should create a savepoint instead.
29    /// This function is expected to increment the transaction depth by 1.
30    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()>;
31
32    /// Rollback the inner-most transaction or savepoint
33    ///
34    /// If the transaction depth is greater than 1,
35    /// this should rollback to the most recent savepoint.
36    /// This function is expected to decrement the transaction depth by 1.
37    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()>;
38
39    /// Commit the inner-most transaction or savepoint
40    ///
41    /// If the transaction depth is greater than 1,
42    /// this should release the most recent savepoint.
43    /// This function is expected to decrement the transaction depth by 1.
44    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()>;
45
46    /// Fetch the current transaction status as mutable
47    ///
48    /// Used to ensure that `begin_test_transaction` is not called when already
49    /// inside of a transaction, and that operations are not run in a `InError`
50    /// transaction manager.
51    #[doc(hidden)]
52    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
53
54    /// Executes the given function inside of a database transaction
55    ///
56    /// Each implementation of this function needs to fulfill the documented
57    /// behaviour of [`AsyncConnection::transaction`]
58    async fn transaction<'a, F, R, E>(conn: &mut Conn, callback: F) -> Result<R, E>
59    where
60        F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
61        E: From<Error> + Send,
62        R: Send,
63    {
64        Self::begin_transaction(conn).await?;
65        match callback(&mut *conn).await {
66            Ok(value) => {
67                Self::commit_transaction(conn).await?;
68                Ok(value)
69            }
70            Err(user_error) => match Self::rollback_transaction(conn).await {
71                Ok(()) => Err(user_error),
72                Err(Error::BrokenTransactionManager) => {
73                    // In this case we are probably more interested by the
74                    // original error, which likely caused this
75                    Err(user_error)
76                }
77                Err(rollback_error) => Err(rollback_error.into()),
78            },
79        }
80    }
81
82    /// This methods checks if the connection manager is considered to be broken
83    /// by connection pool implementations
84    ///
85    /// A connection manager is considered to be broken by default if it either
86    /// contains an open transaction (because you don't want to have connections
87    /// with open transactions in your pool) or when the transaction manager is
88    /// in an error state.
89    #[doc(hidden)]
90    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
91        match Self::transaction_manager_status_mut(conn).transaction_state() {
92            // all transactions are closed
93            // so we don't consider this connection broken
94            Ok(ValidTransactionManagerStatus {
95                in_transaction: None,
96                ..
97            }) => false,
98            // The transaction manager is in an error state
99            // Therefore we consider this connection broken
100            Err(_) => true,
101            // The transaction manager contains a open transaction
102            // we do consider this connection broken
103            // if that transaction was not opened by `begin_test_transaction`
104            Ok(ValidTransactionManagerStatus {
105                in_transaction: Some(s),
106                ..
107            }) => !s.test_transaction,
108        }
109    }
110}
111
112/// An implementation of `TransactionManager` which can be used for backends
113/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
114#[derive(Default, Debug)]
115pub struct AnsiTransactionManager {
116    pub(crate) status: TransactionManagerStatus,
117}
118
119// /// Status of the transaction manager
120// #[derive(Debug)]
121// pub enum TransactionManagerStatus {
122//     /// Valid status, the manager can run operations
123//     Valid(ValidTransactionManagerStatus),
124//     /// Error status, probably following a broken connection. The manager will no longer run operations
125//     InError,
126// }
127
128// impl Default for TransactionManagerStatus {
129//     fn default() -> Self {
130//         TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
131//     }
132// }
133
134// impl TransactionManagerStatus {
135//     /// Returns the transaction depth if the transaction manager's status is valid, or returns
136//     /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
137//     pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
138//         match self {
139//             TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
140//             TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
141//         }
142//     }
143
144//     /// If in transaction and transaction manager is not broken, registers that the
145//     /// connection can not be used anymore until top-level transaction is rolled back
146//     pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) {
147//         if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
148//             in_transaction:
149//                 Some(InTransactionStatus {
150//                     top_level_transaction_requires_rollback,
151//                     ..
152//                 }),
153//         }) = self
154//         {
155//             *top_level_transaction_requires_rollback = true;
156//         }
157//     }
158
159//     /// Sets the transaction manager status to InError
160//     ///
161//     /// Subsequent attempts to use transaction-related features will result in a
162//     /// [`Error::BrokenTransactionManager`] error
163//     pub fn set_in_error(&mut self) {
164//         *self = TransactionManagerStatus::InError
165//     }
166
167//     fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
168//         match self {
169//             TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
170//             TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
171//         }
172//     }
173
174//     pub(crate) fn set_test_transaction_flag(&mut self) {
175//         if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
176//             in_transaction: Some(s),
177//         }) = self
178//         {
179//             s.test_transaction = true;
180//         }
181//     }
182// }
183
184// /// Valid transaction status for the manager. Can return the current transaction depth
185// #[allow(missing_copy_implementations)]
186// #[derive(Debug, Default)]
187// pub struct ValidTransactionManagerStatus {
188//     in_transaction: Option<InTransactionStatus>,
189// }
190
191// #[allow(missing_copy_implementations)]
192// #[derive(Debug)]
193// struct InTransactionStatus {
194//     transaction_depth: NonZeroU32,
195//     top_level_transaction_requires_rollback: bool,
196//     test_transaction: bool,
197// }
198
199// impl ValidTransactionManagerStatus {
200//     /// Return the current transaction depth
201//     ///
202//     /// This value is `None` if no current transaction is running
203//     /// otherwise the number of nested transactions is returned.
204//     pub fn transaction_depth(&self) -> Option<NonZeroU32> {
205//         self.in_transaction.as_ref().map(|it| it.transaction_depth)
206//     }
207
208//     /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
209//     /// `Ok(())`
210//     pub fn change_transaction_depth(
211//         &mut self,
212//         transaction_depth_change: TransactionDepthChange,
213//     ) -> QueryResult<()> {
214//         match (&mut self.in_transaction, transaction_depth_change) {
215//             (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
216//                 // Can be replaced with saturating_add directly on NonZeroU32 once
217//                 // <https://github.com/rust-lang/rust/issues/84186> is stable
218//                 in_transaction.transaction_depth =
219//                     NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
220//                         .expect("nz + nz is always non-zero");
221//                 Ok(())
222//             }
223//             (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
224//                 // This sets `transaction_depth` to `None` as soon as we reach zero
225//                 match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
226//                     Some(depth) => in_transaction.transaction_depth = depth,
227//                     None => self.in_transaction = None,
228//                 }
229//                 Ok(())
230//             }
231//             (None, TransactionDepthChange::IncreaseDepth) => {
232//                 self.in_transaction = Some(InTransactionStatus {
233//                     transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
234//                     top_level_transaction_requires_rollback: false,
235//                     test_transaction: false,
236//                 });
237//                 Ok(())
238//             }
239//             (None, TransactionDepthChange::DecreaseDepth) => {
240//                 // We screwed up something somewhere
241//                 // we cannot decrease the transaction count if
242//                 // we are not inside a transaction
243//                 Err(Error::NotInTransaction)
244//             }
245//         }
246//     }
247// }
248
249// /// Represents a change to apply to the depth of a transaction
250// #[derive(Debug, Clone, Copy)]
251// pub enum TransactionDepthChange {
252//     /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`)
253//     IncreaseDepth,
254//     /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
255//     DecreaseDepth,
256// }
257
258impl AnsiTransactionManager {
259    fn get_transaction_state<Conn>(
260        conn: &mut Conn,
261    ) -> QueryResult<&mut ValidTransactionManagerStatus>
262    where
263        Conn: AsyncConnection<TransactionManager = Self>,
264    {
265        conn.transaction_state().status.transaction_state()
266    }
267
268    /// Begin a transaction with custom SQL
269    ///
270    /// This is used by connections to implement more complex transaction APIs
271    /// to set things such as isolation levels.
272    /// Returns an error if already inside of a transaction.
273    pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
274    where
275        Conn: AsyncConnection<TransactionManager = Self>,
276    {
277        let state = Self::get_transaction_state(conn)?;
278        match state.transaction_depth() {
279            None => {
280                conn.batch_execute(sql).await?;
281                Self::get_transaction_state(conn)?
282                    .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
283                Ok(())
284            }
285            Some(_depth) => Err(Error::AlreadyInTransaction),
286        }
287    }
288}
289
290#[async_trait::async_trait]
291impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
292where
293    Conn: AsyncConnection<TransactionManager = Self>,
294{
295    type TransactionStateData = Self;
296
297    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
298        let transaction_state = Self::get_transaction_state(conn)?;
299        let start_transaction_sql = match transaction_state.transaction_depth() {
300            None => Cow::from("BEGIN"),
301            Some(transaction_depth) => {
302                Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
303            }
304        };
305        let depth = transaction_state
306            .transaction_depth()
307            .and_then(|d| d.checked_add(1))
308            .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
309        conn.instrumentation()
310            .on_connection_event(InstrumentationEvent::begin_transaction(depth));
311        conn.batch_execute(&start_transaction_sql).await?;
312        Self::get_transaction_state(conn)?
313            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
314
315        Ok(())
316    }
317
318    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
319        let transaction_state = Self::get_transaction_state(conn)?;
320
321        let (
322            (rollback_sql, rolling_back_top_level),
323            requires_rollback_maybe_up_to_top_level_before_execute,
324        ) = match transaction_state.in_transaction {
325            Some(ref in_transaction) => (
326                match in_transaction.transaction_depth.get() {
327                    1 => (Cow::Borrowed("ROLLBACK"), true),
328                    depth_gt1 => (
329                        Cow::Owned(format!(
330                            "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
331                            depth_gt1 - 1
332                        )),
333                        false,
334                    ),
335                },
336                in_transaction.requires_rollback_maybe_up_to_top_level,
337            ),
338            None => return Err(Error::NotInTransaction),
339        };
340
341        let depth = transaction_state
342            .transaction_depth()
343            .expect("We know that we are in a transaction here");
344        conn.instrumentation()
345            .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
346
347        match conn.batch_execute(&rollback_sql).await {
348            Ok(()) => {
349                match Self::get_transaction_state(conn)?
350                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
351                {
352                    Ok(()) => {}
353                    Err(Error::NotInTransaction) if rolling_back_top_level => {
354                        // Transaction exit may have already been detected by connection
355                        // implementation. It's fine.
356                    }
357                    Err(e) => return Err(e),
358                }
359                Ok(())
360            }
361            Err(rollback_error) => {
362                let tm_status = Self::transaction_manager_status_mut(conn);
363                match tm_status {
364                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
365                        in_transaction:
366                            Some(InTransactionStatus {
367                                transaction_depth,
368                                requires_rollback_maybe_up_to_top_level,
369                                ..
370                            }),
371                        ..
372                    }) if transaction_depth.get() > 1 => {
373                        // A savepoint failed to rollback - we may still attempt to repair
374                        // the connection by rolling back higher levels.
375
376                        // To make it easier on the user (that they don't have to really
377                        // look at actual transaction depth and can just rely on the number
378                        // of times they have called begin/commit/rollback) we still
379                        // decrement here:
380                        *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
381                            .expect("Depth was checked to be > 1");
382                        *requires_rollback_maybe_up_to_top_level = true;
383                        if requires_rollback_maybe_up_to_top_level_before_execute {
384                            // In that case, we tolerate that savepoint releases fail
385                            // -> we should ignore errors
386                            return Ok(());
387                        }
388                    }
389                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
390                        in_transaction: None,
391                        ..
392                    }) => {
393                        // we would have returned `NotInTransaction` if that was already the state
394                        // before we made our call
395                        // => Transaction manager status has been fixed by the underlying connection
396                        // so we don't need to set_in_error
397                    }
398                    _ => tm_status.set_in_error(),
399                }
400                Err(rollback_error)
401            }
402        }
403    }
404
405    /// If the transaction fails to commit due to a `SerializationFailure` or a
406    /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
407    /// the original error will be returned, otherwise the error generated by the rollback
408    /// will be returned. In the second case the connection will be considered broken
409    /// as it contains a uncommitted unabortable open transaction.
410    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
411        let transaction_state = Self::get_transaction_state(conn)?;
412        let transaction_depth = transaction_state.transaction_depth();
413        let (commit_sql, committing_top_level) = match transaction_depth {
414            None => return Err(Error::NotInTransaction),
415            Some(transaction_depth) if transaction_depth.get() == 1 => {
416                (Cow::Borrowed("COMMIT"), true)
417            }
418            Some(transaction_depth) => (
419                Cow::Owned(format!(
420                    "RELEASE SAVEPOINT diesel_savepoint_{}",
421                    transaction_depth.get() - 1
422                )),
423                false,
424            ),
425        };
426        let depth = transaction_state
427            .transaction_depth()
428            .expect("We know that we are in a transaction here");
429        conn.instrumentation()
430            .on_connection_event(InstrumentationEvent::commit_transaction(depth));
431
432        match conn.batch_execute(&commit_sql).await {
433            Ok(()) => {
434                match Self::get_transaction_state(conn)?
435                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
436                {
437                    Ok(()) => {}
438                    Err(Error::NotInTransaction) if committing_top_level => {
439                        // Transaction exit may have already been detected by connection.
440                        // It's fine
441                    }
442                    Err(e) => return Err(e),
443                }
444                Ok(())
445            }
446            Err(commit_error) => {
447                if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
448                    in_transaction:
449                        Some(InTransactionStatus {
450                            requires_rollback_maybe_up_to_top_level: true,
451                            ..
452                        }),
453                    ..
454                }) = conn.transaction_state().status
455                {
456                    match Self::rollback_transaction(conn).await {
457                        Ok(()) => {}
458                        Err(rollback_error) => {
459                            conn.transaction_state().status.set_in_error();
460                            return Err(Error::RollbackErrorOnCommit {
461                                rollback_error: Box::new(rollback_error),
462                                commit_error: Box::new(commit_error),
463                            });
464                        }
465                    }
466                }
467                Err(commit_error)
468            }
469        }
470    }
471
472    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
473        &mut conn.transaction_state().status
474    }
475}