use crate::client::{Addr, SocketConfig};
use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::MakeTlsConnect;
use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use rand::seq::SliceRandom;
use std::task::Poll;
use std::{cmp, io};
use tokio::net;
/// Establishes a connection to one of the hosts described by `config`.
///
/// The configuration is validated first, then every configured host is tried
/// in turn (in random order when `load_balance_hosts` is `Random`) until one
/// attempt succeeds.
///
/// # Errors
///
/// Returns a configuration error when neither `host` nor `hostaddr` is set,
/// when both are set but differ in length, or when more than one port is
/// given and the port count does not match the host count. If every attempt
/// fails, the error from the last attempt is returned.
pub async fn connect<T>(
    mut tls: T,
    config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
    T: MakeTlsConnect<Socket>,
{
    let host_count = config.host.len();
    let hostaddr_count = config.hostaddr.len();

    if host_count == 0 && hostaddr_count == 0 {
        return Err(Error::config("both host and hostaddr are missing".into()));
    }

    // When both lists are supplied they must pair up one-to-one.
    if host_count != 0 && hostaddr_count != 0 && host_count != hostaddr_count {
        let msg = format!(
            "number of hosts ({}) is different from number of hostaddrs ({})",
            host_count, hostaddr_count,
        );
        return Err(Error::config(msg.into()));
    }

    // At least one of the two lists is non-empty here, so this is >= 1.
    let num_hosts = cmp::max(host_count, hostaddr_count);

    // Zero or one port applies to every host; otherwise ports map one-to-one.
    let port_count = config.port.len();
    if port_count > 1 && port_count != num_hosts {
        return Err(Error::config("invalid number of ports".into()));
    }

    let mut order: Vec<usize> = (0..num_hosts).collect();
    if config.load_balance_hosts == LoadBalanceHosts::Random {
        order.shuffle(&mut rand::thread_rng());
    }

    let mut last_error = None;
    for idx in order {
        let host = config.host.get(idx);

        // Per-host port, else the single shared port, else the default 5432.
        let port = match config.port.get(idx).or_else(|| config.port.first()) {
            Some(&p) => p,
            None => 5432,
        };

        // Only TCP hosts yield a hostname to hand to the TLS layer.
        let hostname = host.and_then(|h| match h {
            Host::Tcp(name) => Some(name.clone()),
            #[cfg(unix)]
            Host::Unix(_) => None,
        });

        // A hostaddr, when present, is the address actually dialed; otherwise
        // the host entry must exist because hostaddr is empty in that case.
        let addr = if let Some(ip) = config.hostaddr.get(idx) {
            Host::Tcp(ip.to_string())
        } else {
            host.cloned().unwrap()
        };

        match connect_host(addr, hostname, port, &mut tls, config).await {
            Ok(established) => return Ok(established),
            Err(e) => last_error = Some(e),
        }
    }

    // The loop ran at least once, so a failure has been recorded.
    Err(last_error.unwrap())
}
async fn connect_host<T>(
host: Host,
hostname: Option<String>,
port: u16,
tls: &mut T,
config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
T: MakeTlsConnect<Socket>,
{
match host {
Host::Tcp(host) => {
let mut addrs = net::lookup_host((&*host, port))
.await
.map_err(Error::connect)?
.collect::<Vec<_>>();
if config.load_balance_hosts == LoadBalanceHosts::Random {
addrs.shuffle(&mut rand::thread_rng());
}
let mut last_err = None;
for addr in addrs {
match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
.await
{
Ok(stream) => return Ok(stream),
Err(e) => {
last_err = Some(e);
continue;
}
};
}
Err(last_err.unwrap_or_else(|| {
Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}))
}
#[cfg(unix)]
Host::Unix(path) => {
connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
}
}
}
/// Makes one connection attempt against a single address.
///
/// Opens the socket, builds the TLS connector for `hostname` (an empty string
/// is used when no hostname is available), and runs `connect_raw`. Unless
/// `target_session_attrs` is `Any`, it then issues
/// `SHOW transaction_read_only` and rejects the session when the reported
/// value conflicts with the requested attribute. On success the socket
/// parameters are recorded on the client via `set_socket_config`.
///
/// # Errors
///
/// Propagates socket, TLS and `connect_raw` errors. While the session check
/// is running, returns `Error::closed()` if the connection future completes,
/// `Error::unexpected_message()` if the response ends without a row, and a
/// `PermissionDenied` connect error when the session attribute does not match.
async fn connect_once<T>(
    addr: Addr,
    hostname: Option<&str>,
    port: u16,
    tls: &mut T,
    config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
    T: MakeTlsConnect<Socket>,
{
    let socket_keepalive = if config.keepalives {
        Some(&config.keepalive_config)
    } else {
        None
    };
    let socket = connect_socket(
        &addr,
        port,
        config.connect_timeout,
        config.tcp_user_timeout,
        socket_keepalive,
    )
    .await?;

    let connector = tls
        .make_tls_connect(hostname.unwrap_or(""))
        .map_err(|e| Error::tls(e.into()))?;

    let (mut client, mut connection) =
        connect_raw(socket, connector, hostname.is_some(), config).await?;

    if config.target_session_attrs != TargetSessionAttrs::Any {
        let query = client.simple_query_raw("SHOW transaction_read_only");
        pin_mut!(query);

        // The connection is polled by hand alongside the query; if it
        // completes first the attempt is reported as closed.
        let messages = future::poll_fn(|cx| {
            if connection.poll_unpin(cx)?.is_ready() {
                return Poll::Ready(Err(Error::closed()));
            }
            query.as_mut().poll(cx)
        })
        .await?;
        pin_mut!(messages);

        loop {
            let message = future::poll_fn(|cx| {
                if connection.poll_unpin(cx)?.is_ready() {
                    return Poll::Ready(Some(Err(Error::closed())));
                }
                messages.as_mut().poll_next(cx)
            })
            .await
            .transpose()?;

            let row = match message {
                Some(SimpleQueryMessage::Row(row)) => row,
                // Skip non-row messages and keep reading.
                Some(_) => continue,
                None => return Err(Error::unexpected_message()),
            };

            let read_only = row.try_get(0)?;
            let rejection = if read_only == Some("on")
                && config.target_session_attrs == TargetSessionAttrs::ReadWrite
            {
                Some("database does not allow writes")
            } else if read_only == Some("off")
                && config.target_session_attrs == TargetSessionAttrs::ReadOnly
            {
                Some("database is not read only")
            } else {
                None
            };

            if let Some(reason) = rejection {
                return Err(Error::connect(io::Error::new(
                    io::ErrorKind::PermissionDenied,
                    reason,
                )));
            }
            break;
        }
    }

    let stored_keepalive = if config.keepalives {
        Some(config.keepalive_config.clone())
    } else {
        None
    };
    client.set_socket_config(SocketConfig {
        addr,
        hostname: hostname.map(str::to_string),
        port,
        connect_timeout: config.connect_timeout,
        tcp_user_timeout: config.tcp_user_timeout,
        keepalive: stored_keepalive,
    });

    Ok((client, connection))
}