diesel_async/transaction_manager.rs
1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4 InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use scoped_futures::ScopedBoxFuture;
9use std::borrow::Cow;
10use std::num::NonZeroU32;
11
12use crate::AsyncConnection;
13// TODO: refactor this to share more code with diesel
14
15/// Manages the internal transaction state for a connection.
16///
17/// You will not need to interact with this trait, unless you are writing an
18/// implementation of [`AsyncConnection`].
19#[async_trait::async_trait]
20pub trait TransactionManager<Conn: AsyncConnection>: Send {
21 /// Data stored as part of the connection implementation
22 /// to track the current transaction state of a connection
23 type TransactionStateData;
24
25 /// Begin a new transaction or savepoint
26 ///
27 /// If the transaction depth is greater than 0,
28 /// this should create a savepoint instead.
29 /// This function is expected to increment the transaction depth by 1.
30 async fn begin_transaction(conn: &mut Conn) -> QueryResult<()>;
31
32 /// Rollback the inner-most transaction or savepoint
33 ///
34 /// If the transaction depth is greater than 1,
35 /// this should rollback to the most recent savepoint.
36 /// This function is expected to decrement the transaction depth by 1.
37 async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()>;
38
39 /// Commit the inner-most transaction or savepoint
40 ///
41 /// If the transaction depth is greater than 1,
42 /// this should release the most recent savepoint.
43 /// This function is expected to decrement the transaction depth by 1.
44 async fn commit_transaction(conn: &mut Conn) -> QueryResult<()>;
45
46 /// Fetch the current transaction status as mutable
47 ///
48 /// Used to ensure that `begin_test_transaction` is not called when already
49 /// inside of a transaction, and that operations are not run in a `InError`
50 /// transaction manager.
51 #[doc(hidden)]
52 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
53
54 /// Executes the given function inside of a database transaction
55 ///
56 /// Each implementation of this function needs to fulfill the documented
57 /// behaviour of [`AsyncConnection::transaction`]
58 async fn transaction<'a, F, R, E>(conn: &mut Conn, callback: F) -> Result<R, E>
59 where
60 F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
61 E: From<Error> + Send,
62 R: Send,
63 {
64 Self::begin_transaction(conn).await?;
65 match callback(&mut *conn).await {
66 Ok(value) => {
67 Self::commit_transaction(conn).await?;
68 Ok(value)
69 }
70 Err(user_error) => match Self::rollback_transaction(conn).await {
71 Ok(()) => Err(user_error),
72 Err(Error::BrokenTransactionManager) => {
73 // In this case we are probably more interested by the
74 // original error, which likely caused this
75 Err(user_error)
76 }
77 Err(rollback_error) => Err(rollback_error.into()),
78 },
79 }
80 }
81
82 /// This methods checks if the connection manager is considered to be broken
83 /// by connection pool implementations
84 ///
85 /// A connection manager is considered to be broken by default if it either
86 /// contains an open transaction (because you don't want to have connections
87 /// with open transactions in your pool) or when the transaction manager is
88 /// in an error state.
89 #[doc(hidden)]
90 fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
91 match Self::transaction_manager_status_mut(conn).transaction_state() {
92 // all transactions are closed
93 // so we don't consider this connection broken
94 Ok(ValidTransactionManagerStatus {
95 in_transaction: None,
96 ..
97 }) => false,
98 // The transaction manager is in an error state
99 // Therefore we consider this connection broken
100 Err(_) => true,
101 // The transaction manager contains a open transaction
102 // we do consider this connection broken
103 // if that transaction was not opened by `begin_test_transaction`
104 Ok(ValidTransactionManagerStatus {
105 in_transaction: Some(s),
106 ..
107 }) => !s.test_transaction,
108 }
109 }
110}
111
112/// An implementation of `TransactionManager` which can be used for backends
113/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
114#[derive(Default, Debug)]
115pub struct AnsiTransactionManager {
116 pub(crate) status: TransactionManagerStatus,
117}
118
119// /// Status of the transaction manager
120// #[derive(Debug)]
121// pub enum TransactionManagerStatus {
122// /// Valid status, the manager can run operations
123// Valid(ValidTransactionManagerStatus),
124// /// Error status, probably following a broken connection. The manager will no longer run operations
125// InError,
126// }
127
128// impl Default for TransactionManagerStatus {
129// fn default() -> Self {
130// TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
131// }
132// }
133
134// impl TransactionManagerStatus {
135// /// Returns the transaction depth if the transaction manager's status is valid, or returns
136// /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
137// pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
138// match self {
139// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
140// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
141// }
142// }
143
144// /// If in transaction and transaction manager is not broken, registers that the
145// /// connection can not be used anymore until top-level transaction is rolled back
146// pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) {
147// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
148// in_transaction:
149// Some(InTransactionStatus {
150// top_level_transaction_requires_rollback,
151// ..
152// }),
153// }) = self
154// {
155// *top_level_transaction_requires_rollback = true;
156// }
157// }
158
159// /// Sets the transaction manager status to InError
160// ///
161// /// Subsequent attempts to use transaction-related features will result in a
162// /// [`Error::BrokenTransactionManager`] error
163// pub fn set_in_error(&mut self) {
164// *self = TransactionManagerStatus::InError
165// }
166
167// fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
168// match self {
169// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
170// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
171// }
172// }
173
174// pub(crate) fn set_test_transaction_flag(&mut self) {
175// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
176// in_transaction: Some(s),
177// }) = self
178// {
179// s.test_transaction = true;
180// }
181// }
182// }
183
184// /// Valid transaction status for the manager. Can return the current transaction depth
185// #[allow(missing_copy_implementations)]
186// #[derive(Debug, Default)]
187// pub struct ValidTransactionManagerStatus {
188// in_transaction: Option<InTransactionStatus>,
189// }
190
191// #[allow(missing_copy_implementations)]
192// #[derive(Debug)]
193// struct InTransactionStatus {
194// transaction_depth: NonZeroU32,
195// top_level_transaction_requires_rollback: bool,
196// test_transaction: bool,
197// }
198
199// impl ValidTransactionManagerStatus {
200// /// Return the current transaction depth
201// ///
202// /// This value is `None` if no current transaction is running
203// /// otherwise the number of nested transactions is returned.
204// pub fn transaction_depth(&self) -> Option<NonZeroU32> {
205// self.in_transaction.as_ref().map(|it| it.transaction_depth)
206// }
207
208// /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
209// /// `Ok(())`
210// pub fn change_transaction_depth(
211// &mut self,
212// transaction_depth_change: TransactionDepthChange,
213// ) -> QueryResult<()> {
214// match (&mut self.in_transaction, transaction_depth_change) {
215// (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
216// // Can be replaced with saturating_add directly on NonZeroU32 once
217// // <https://github.com/rust-lang/rust/issues/84186> is stable
218// in_transaction.transaction_depth =
219// NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
220// .expect("nz + nz is always non-zero");
221// Ok(())
222// }
223// (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
224// // This sets `transaction_depth` to `None` as soon as we reach zero
225// match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
226// Some(depth) => in_transaction.transaction_depth = depth,
227// None => self.in_transaction = None,
228// }
229// Ok(())
230// }
231// (None, TransactionDepthChange::IncreaseDepth) => {
232// self.in_transaction = Some(InTransactionStatus {
233// transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
234// top_level_transaction_requires_rollback: false,
235// test_transaction: false,
236// });
237// Ok(())
238// }
239// (None, TransactionDepthChange::DecreaseDepth) => {
240// // We screwed up something somewhere
241// // we cannot decrease the transaction count if
242// // we are not inside a transaction
243// Err(Error::NotInTransaction)
244// }
245// }
246// }
247// }
248
249// /// Represents a change to apply to the depth of a transaction
250// #[derive(Debug, Clone, Copy)]
251// pub enum TransactionDepthChange {
252// /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`)
253// IncreaseDepth,
254// /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
255// DecreaseDepth,
256// }
257
258impl AnsiTransactionManager {
259 fn get_transaction_state<Conn>(
260 conn: &mut Conn,
261 ) -> QueryResult<&mut ValidTransactionManagerStatus>
262 where
263 Conn: AsyncConnection<TransactionManager = Self>,
264 {
265 conn.transaction_state().status.transaction_state()
266 }
267
268 /// Begin a transaction with custom SQL
269 ///
270 /// This is used by connections to implement more complex transaction APIs
271 /// to set things such as isolation levels.
272 /// Returns an error if already inside of a transaction.
273 pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
274 where
275 Conn: AsyncConnection<TransactionManager = Self>,
276 {
277 let state = Self::get_transaction_state(conn)?;
278 match state.transaction_depth() {
279 None => {
280 conn.batch_execute(sql).await?;
281 Self::get_transaction_state(conn)?
282 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
283 Ok(())
284 }
285 Some(_depth) => Err(Error::AlreadyInTransaction),
286 }
287 }
288}
289
290#[async_trait::async_trait]
291impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
292where
293 Conn: AsyncConnection<TransactionManager = Self>,
294{
295 type TransactionStateData = Self;
296
297 async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
298 let transaction_state = Self::get_transaction_state(conn)?;
299 let start_transaction_sql = match transaction_state.transaction_depth() {
300 None => Cow::from("BEGIN"),
301 Some(transaction_depth) => {
302 Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
303 }
304 };
305 let depth = transaction_state
306 .transaction_depth()
307 .and_then(|d| d.checked_add(1))
308 .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
309 conn.instrumentation()
310 .on_connection_event(InstrumentationEvent::begin_transaction(depth));
311 conn.batch_execute(&start_transaction_sql).await?;
312 Self::get_transaction_state(conn)?
313 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
314
315 Ok(())
316 }
317
318 async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
319 let transaction_state = Self::get_transaction_state(conn)?;
320
321 let (
322 (rollback_sql, rolling_back_top_level),
323 requires_rollback_maybe_up_to_top_level_before_execute,
324 ) = match transaction_state.in_transaction {
325 Some(ref in_transaction) => (
326 match in_transaction.transaction_depth.get() {
327 1 => (Cow::Borrowed("ROLLBACK"), true),
328 depth_gt1 => (
329 Cow::Owned(format!(
330 "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
331 depth_gt1 - 1
332 )),
333 false,
334 ),
335 },
336 in_transaction.requires_rollback_maybe_up_to_top_level,
337 ),
338 None => return Err(Error::NotInTransaction),
339 };
340
341 let depth = transaction_state
342 .transaction_depth()
343 .expect("We know that we are in a transaction here");
344 conn.instrumentation()
345 .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
346
347 match conn.batch_execute(&rollback_sql).await {
348 Ok(()) => {
349 match Self::get_transaction_state(conn)?
350 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
351 {
352 Ok(()) => {}
353 Err(Error::NotInTransaction) if rolling_back_top_level => {
354 // Transaction exit may have already been detected by connection
355 // implementation. It's fine.
356 }
357 Err(e) => return Err(e),
358 }
359 Ok(())
360 }
361 Err(rollback_error) => {
362 let tm_status = Self::transaction_manager_status_mut(conn);
363 match tm_status {
364 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
365 in_transaction:
366 Some(InTransactionStatus {
367 transaction_depth,
368 requires_rollback_maybe_up_to_top_level,
369 ..
370 }),
371 ..
372 }) if transaction_depth.get() > 1 => {
373 // A savepoint failed to rollback - we may still attempt to repair
374 // the connection by rolling back higher levels.
375
376 // To make it easier on the user (that they don't have to really
377 // look at actual transaction depth and can just rely on the number
378 // of times they have called begin/commit/rollback) we still
379 // decrement here:
380 *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
381 .expect("Depth was checked to be > 1");
382 *requires_rollback_maybe_up_to_top_level = true;
383 if requires_rollback_maybe_up_to_top_level_before_execute {
384 // In that case, we tolerate that savepoint releases fail
385 // -> we should ignore errors
386 return Ok(());
387 }
388 }
389 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
390 in_transaction: None,
391 ..
392 }) => {
393 // we would have returned `NotInTransaction` if that was already the state
394 // before we made our call
395 // => Transaction manager status has been fixed by the underlying connection
396 // so we don't need to set_in_error
397 }
398 _ => tm_status.set_in_error(),
399 }
400 Err(rollback_error)
401 }
402 }
403 }
404
405 /// If the transaction fails to commit due to a `SerializationFailure` or a
406 /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
407 /// the original error will be returned, otherwise the error generated by the rollback
408 /// will be returned. In the second case the connection will be considered broken
409 /// as it contains a uncommitted unabortable open transaction.
410 async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
411 let transaction_state = Self::get_transaction_state(conn)?;
412 let transaction_depth = transaction_state.transaction_depth();
413 let (commit_sql, committing_top_level) = match transaction_depth {
414 None => return Err(Error::NotInTransaction),
415 Some(transaction_depth) if transaction_depth.get() == 1 => {
416 (Cow::Borrowed("COMMIT"), true)
417 }
418 Some(transaction_depth) => (
419 Cow::Owned(format!(
420 "RELEASE SAVEPOINT diesel_savepoint_{}",
421 transaction_depth.get() - 1
422 )),
423 false,
424 ),
425 };
426 let depth = transaction_state
427 .transaction_depth()
428 .expect("We know that we are in a transaction here");
429 conn.instrumentation()
430 .on_connection_event(InstrumentationEvent::commit_transaction(depth));
431
432 match conn.batch_execute(&commit_sql).await {
433 Ok(()) => {
434 match Self::get_transaction_state(conn)?
435 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
436 {
437 Ok(()) => {}
438 Err(Error::NotInTransaction) if committing_top_level => {
439 // Transaction exit may have already been detected by connection.
440 // It's fine
441 }
442 Err(e) => return Err(e),
443 }
444 Ok(())
445 }
446 Err(commit_error) => {
447 if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
448 in_transaction:
449 Some(InTransactionStatus {
450 requires_rollback_maybe_up_to_top_level: true,
451 ..
452 }),
453 ..
454 }) = conn.transaction_state().status
455 {
456 match Self::rollback_transaction(conn).await {
457 Ok(()) => {}
458 Err(rollback_error) => {
459 conn.transaction_state().status.set_in_error();
460 return Err(Error::RollbackErrorOnCommit {
461 rollback_error: Box::new(rollback_error),
462 commit_error: Box::new(commit_error),
463 });
464 }
465 }
466 }
467 Err(commit_error)
468 }
469 }
470 }
471
472 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
473 &mut conn.transaction_state().status
474 }
475}