use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::copy_out::CopyOutStream;
use crate::query::RowStream;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::types::{BorrowToSql, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{
bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
SimpleQueryMessage, Statement, ToStatement,
};
use bytes::Buf;
use futures_util::TryStreamExt;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
/// A handle to an in-progress database transaction (or nested savepoint).
///
/// Dropping the handle without calling `commit` or `rollback` queues a
/// rollback (see the `Drop` impl); call `commit` to keep the changes.
pub struct Transaction<'a> {
// Client this transaction runs on; held by `&mut` for the handle's lifetime.
client: &'a mut Client,
// `None` for a top-level transaction (as built by `new`), `Some` when this
// handle was created by `_savepoint` and is backed by a savepoint.
savepoint: Option<Savepoint>,
// Set to `true` by `commit`/`rollback` so that `Drop` sends nothing further.
done: bool,
}
/// The savepoint backing a nested transaction.
struct Savepoint {
// Savepoint identifier; interpolated as-is into the `SAVEPOINT`, `RELEASE`
// and `ROLLBACK TO` statements built in this file.
name: String,
// Nesting level: the parent's depth plus one, so a savepoint created on a
// top-level transaction has depth 1. Used to derive default `sp_{depth}` names.
depth: u32,
}
/// Rolls the transaction back when the handle is dropped without having been
/// finished through `commit` or `rollback`.
impl<'a> Drop for Transaction<'a> {
    fn drop(&mut self) {
        // `commit`/`rollback` flip `done`; in that case there is nothing left to undo.
        if !self.done {
            // Nested transactions only unwind to their own savepoint.
            let rollback = match &self.savepoint {
                Some(sp) => format!("ROLLBACK TO {}", sp.name),
                None => "ROLLBACK".to_string(),
            };

            // `drop` cannot await, so the message is encoded by hand and handed
            // to the connection without waiting for a response.
            let message = self.client.inner().with_buf(|buf| {
                // NOTE(review): `unwrap` panics if encoding fails. Presumably
                // unreachable, since a savepoint name has already been sent in a
                // `SAVEPOINT` statement by `_savepoint` — confirm.
                frontend::query(&rollback, buf).unwrap();
                buf.split().freeze()
            });

            // Best effort: the result of `send` is deliberately ignored.
            let _ = self
                .client
                .inner()
                .send(RequestMessages::Single(FrontendMessage::Raw(message)));
        }
    }
}
impl<'a> Transaction<'a> {
    /// Wraps `client` in a top-level transaction handle: no savepoint, not yet
    /// finished.
    ///
    /// The constructor itself sends nothing to the server.
    pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
        Transaction {
            done: false,
            savepoint: None,
            client,
        }
    }

    /// Consumes the transaction, committing the changes made within it.
    ///
    /// A savepoint-backed (nested) transaction issues `RELEASE <name>`; a
    /// top-level one issues `COMMIT`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Client::batch_execute` reports for the statement.
    pub async fn commit(mut self) -> Result<(), Error> {
        // Mark as finished first so `Drop` does not queue a rollback afterwards,
        // whatever the outcome of the statement below.
        self.done = true;
        let statement = match &self.savepoint {
            Some(sp) => format!("RELEASE {}", sp.name),
            None => "COMMIT".to_string(),
        };
        self.client.batch_execute(&statement).await
    }

    /// Consumes the transaction, rolling back the changes made within it.
    ///
    /// This issues the same statement as dropping the handle (`ROLLBACK TO
    /// <name>` for a nested transaction, `ROLLBACK` otherwise), but awaits the
    /// result and reports any error to the caller.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Client::batch_execute` reports for the statement.
    pub async fn rollback(mut self) -> Result<(), Error> {
        // Mark as finished first so `Drop` does not send a second rollback.
        self.done = true;
        let statement = match &self.savepoint {
            Some(sp) => format!("ROLLBACK TO {}", sp.name),
            None => "ROLLBACK".to_string(),
        };
        self.client.batch_execute(&statement).await
    }

    /// Forwards to `Client::prepare` on the underlying client.
    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
        self.client.prepare(query).await
    }

    /// Forwards to `Client::prepare_typed` on the underlying client.
    pub async fn prepare_typed(
        &self,
        query: &str,
        parameter_types: &[Type],
    ) -> Result<Statement, Error> {
        self.client.prepare_typed(query, parameter_types).await
    }

    /// Forwards to `Client::query` on the underlying client.
    pub async fn query<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Vec<Row>, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.query(statement, params).await
    }

    /// Forwards to `Client::query_one` on the underlying client.
    pub async fn query_one<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Row, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.query_one(statement, params).await
    }

    /// Forwards to `Client::query_opt` on the underlying client.
    pub async fn query_opt<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Option<Row>, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.query_opt(statement, params).await
    }

    /// Forwards to `Client::query_raw` on the underlying client.
    pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
    where
        T: ?Sized + ToStatement,
        P: BorrowToSql,
        I: IntoIterator<Item = P>,
        I::IntoIter: ExactSizeIterator,
    {
        self.client.query_raw(statement, params).await
    }

    /// Forwards to `Client::query_typed` on the underlying client.
    pub async fn query_typed(
        &self,
        statement: &str,
        params: &[(&(dyn ToSql + Sync), Type)],
    ) -> Result<Vec<Row>, Error> {
        self.client.query_typed(statement, params).await
    }

    /// Forwards to `Client::query_typed_raw` on the underlying client.
    pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
    where
        P: BorrowToSql,
        I: IntoIterator<Item = (P, Type)>,
    {
        self.client.query_typed_raw(query, params).await
    }

    /// Forwards to `Client::execute` on the underlying client.
    pub async fn execute<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<u64, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.execute(statement, params).await
    }

    /// Forwards to `Client::execute_raw` on the underlying client.
    pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
    where
        T: ?Sized + ToStatement,
        P: BorrowToSql,
        I: IntoIterator<Item = P>,
        I::IntoIter: ExactSizeIterator,
    {
        self.client.execute_raw(statement, params).await
    }

    /// Binds `statement` to `params`, producing a `Portal`.
    ///
    /// Convenience wrapper over [`Transaction::bind_raw`] that adapts the
    /// parameter slice with `slice_iter`.
    pub async fn bind<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Portal, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.bind_raw(statement, slice_iter(params)).await
    }

    /// Binds `statement` to the parameters yielded by `params`, producing a
    /// `Portal`.
    ///
    /// The statement is first converted into a `Statement` against the
    /// underlying client, then handed to `bind::bind`.
    ///
    /// # Errors
    ///
    /// Fails if the statement conversion or the bind itself fails.
    pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
    where
        T: ?Sized + ToStatement,
        P: BorrowToSql,
        I: IntoIterator<Item = P>,
        I::IntoIter: ExactSizeIterator,
    {
        let prepared = statement.__convert().into_statement(self.client).await?;
        bind::bind(self.client.inner(), prepared, params).await
    }

    /// Runs `portal` and collects the resulting rows into a `Vec`.
    ///
    /// `max_rows` is passed through unchanged to
    /// [`Transaction::query_portal_raw`].
    // NOTE(review): presumably a non-positive `max_rows` means "all rows" —
    // confirm against `query::query_portal`.
    pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
        let rows = self.query_portal_raw(portal, max_rows).await?;
        rows.try_collect().await
    }

    /// Runs `portal` and returns the rows as a stream instead of collecting
    /// them.
    ///
    /// Delegates to `query::query_portal`, passing `max_rows` through unchanged.
    pub async fn query_portal_raw(
        &self,
        portal: &Portal,
        max_rows: i32,
    ) -> Result<RowStream, Error> {
        query::query_portal(self.client.inner(), portal, max_rows).await
    }

    /// Forwards to `Client::copy_in` on the underlying client.
    pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
    where
        T: ?Sized + ToStatement,
        U: Buf + 'static + Send,
    {
        self.client.copy_in(statement).await
    }

    /// Forwards to `Client::copy_out` on the underlying client.
    pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.copy_out(statement).await
    }

    /// Forwards to `Client::simple_query` on the underlying client.
    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
        self.client.simple_query(query).await
    }

    /// Forwards to `Client::batch_execute` on the underlying client.
    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
        self.client.batch_execute(query).await
    }

    /// Forwards to `Client::cancel_token` on the underlying client.
    pub fn cancel_token(&self) -> CancelToken {
        self.client.cancel_token()
    }

    /// Forwards to the deprecated `Client::cancel_query` on the underlying
    /// client.
    #[cfg(feature = "runtime")]
    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
    where
        T: MakeTlsConnect<Socket>,
    {
        // The delegate is deprecated too; silence the warning at the call site only.
        #[allow(deprecated)]
        self.client.cancel_query(tls).await
    }

    /// Forwards to the deprecated `Client::cancel_query_raw` on the underlying
    /// client.
    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        T: TlsConnect<S>,
    {
        // The delegate is deprecated too; silence the warning at the call site only.
        #[allow(deprecated)]
        self.client.cancel_query_raw(stream, tls).await
    }

    /// Starts a nested transaction backed by a savepoint with an automatically
    /// generated name (`sp_<depth>`).
    ///
    /// The returned handle mutably borrows `self`, so the outer transaction
    /// cannot be used until the nested one has been finished or dropped.
    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
        self._savepoint(None).await
    }

    /// Starts a nested transaction backed by a savepoint called `name`.
    ///
    /// `name` is interpolated into the SQL text verbatim, without quoting or
    /// escaping, so it must be a valid identifier from a trusted source.
    pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
    where
        I: Into<String>,
    {
        self._savepoint(Some(name.into())).await
    }

    /// Shared implementation of [`Transaction::transaction`] and
    /// [`Transaction::savepoint`].
    ///
    /// Issues `SAVEPOINT <name>` and, on success, returns a child handle one
    /// level deeper than `self`. When `name` is `None`, `sp_<depth>` is used.
    async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
        // A top-level transaction has no savepoint and counts as depth 0.
        let parent_depth = match &self.savepoint {
            Some(sp) => sp.depth,
            None => 0,
        };
        let depth = parent_depth + 1;

        let name = match name {
            Some(explicit) => explicit,
            None => format!("sp_{}", depth),
        };

        // Only hand out the child handle once the server accepted the savepoint.
        let statement = format!("SAVEPOINT {}", name);
        self.batch_execute(&statement).await?;

        Ok(Transaction {
            client: self.client,
            savepoint: Some(Savepoint { name, depth }),
            done: false,
        })
    }

    /// Returns a shared reference to the underlying `Client`.
    pub fn client(&self) -> &Client {
        self.client
    }
}