tokio_postgres/
config.rs

1//! Connection configuration.
2
3#[cfg(feature = "runtime")]
4use crate::connect::connect;
5use crate::connect_raw::connect_raw;
6#[cfg(not(target_arch = "wasm32"))]
7use crate::keepalive::KeepaliveConfig;
8#[cfg(feature = "runtime")]
9use crate::tls::MakeTlsConnect;
10use crate::tls::TlsConnect;
11#[cfg(feature = "runtime")]
12use crate::Socket;
13use crate::{Client, Connection, Error};
14use std::borrow::Cow;
15#[cfg(unix)]
16use std::ffi::OsStr;
17use std::net::IpAddr;
18use std::ops::Deref;
19#[cfg(unix)]
20use std::os::unix::ffi::OsStrExt;
21#[cfg(unix)]
22use std::path::{Path, PathBuf};
23use std::str;
24use std::str::FromStr;
25use std::time::Duration;
26use std::{error, fmt, iter, mem};
27use tokio::io::{AsyncRead, AsyncWrite};
28
/// Properties required of a session.
///
/// In connection strings this is the `target_session_attrs` option, spelled `any`,
/// `read-write`, or `read-only`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// No special properties are required.
    Any,
    /// The session must allow writes.
    ReadWrite,
    /// The session must allow only reads.
    ReadOnly,
}
40
/// TLS configuration.
///
/// Selected in connection strings through the `sslmode` option (`disable`, `prefer`,
/// or `require`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
    /// Never negotiate TLS.
    Disable,
    /// Try TLS first, but fall back to an unencrypted session if it is unavailable.
    Prefer,
    /// Refuse to proceed without TLS.
    Require,
}
52
/// Channel binding configuration.
///
/// Selected in connection strings through the `channel_binding` option (`disable`,
/// `prefer`, or `require`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Never use channel binding.
    Disable,
    /// Use channel binding when possible, otherwise continue without it.
    Prefer,
    /// Fail unless channel binding is used.
    Require,
}
64
/// Load balancing configuration.
///
/// Selected in connection strings through the `load_balance_hosts` option (`disable` or
/// `random`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum LoadBalanceHosts {
    /// Attempt hosts in exactly the order they were given.
    Disable,
    /// Attempt hosts in a shuffled order.
    Random,
}
74
/// A host specification.
///
/// Produced by `Config::host` (and, on Unix, `Config::host_path`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A hostname (or textual address) reached over TCP.
    Tcp(String),
    /// The directory holding the server's Unix domain socket.
    ///
    /// Only present on Unix platforms.
    #[cfg(unix)]
    Unix(PathBuf),
}
86
87/// Connection configuration.
88///
89/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
90///
91/// # Key-Value
92///
93/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain
94/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped.
95///
96/// ## Keys
97///
98/// * `user` - The username to authenticate with. Defaults to the user executing this process.
99/// * `password` - The password to authenticate with.
100/// * `dbname` - The name of the database to connect to. Defaults to the username.
101/// * `options` - Command line options used to configure the server.
102/// * `application_name` - Sets the `application_name` parameter on the server.
103/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
104///     if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
105/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
106///     path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
107///     can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
108///     with the `connect` method.
109/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format,
110///     e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses.
111///     If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address,
112///     or if host specifies an IP address, that value will be used directly.
113///     Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications
114///     with time constraints. However, a host name is required for TLS certificate verification.
115///     Specifically:
116///         * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address.
117///             The connection attempt will fail if the authentication method requires a host name;
118///         * If `host` is specified without `hostaddr`, a host name lookup occurs;
119///         * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address.
120///             The value for `host` is ignored unless the authentication method requires it,
121///             in which case it will be used as the host name.
122/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be
123///     either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if
124///     omitted or the empty string.
125/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames
126///     can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout.
127/// * `tcp_user_timeout` - The time limit that transmitted data may remain unacknowledged before a connection is forcibly closed.
128///     This is ignored for Unix domain socket connections. It is only supported on systems where TCP_USER_TIMEOUT is available
129///     and will default to the system default if omitted or set to 0; on other systems, it has no effect.
130/// * `keepalives` - Controls the use of TCP keepalive. A value of 0 disables keepalive and nonzero integers enable it.
131///     This option is ignored when connecting with Unix sockets. Defaults to on.
132/// * `keepalives_idle` - The number of seconds of inactivity after which a keepalive message is sent to the server.
133///     This option is ignored when connecting with Unix sockets. Defaults to 2 hours.
134/// * `keepalives_interval` - The time interval between TCP keepalive probes.
135///     This option is ignored when connecting with Unix sockets.
136/// * `keepalives_retries` - The maximum number of TCP keepalive probes that will be sent before dropping a connection.
137///     This option is ignored when connecting with Unix sockets.
138/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
139///     the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
140///     in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
141/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
142///     binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
143///     If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
144/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and
145///     addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter
146///     is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to
147///     `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried
148///     in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults
149///     to `disable`.
150///
151/// ## Examples
152///
153/// ```not_rust
154/// host=localhost user=postgres connect_timeout=10 keepalives=0
155/// ```
156///
157/// ```not_rust
158/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces'
159/// ```
160///
161/// ```not_rust
162/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write
163/// ```
164///
165/// ```not_rust
166/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write
167/// ```
168///
169/// # Url
170///
171/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional,
172/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple
173/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded,
174/// as the path component of the URL specifies the database name.
175///
176/// ## Examples
177///
178/// ```not_rust
179/// postgresql://user@localhost
180/// ```
181///
182/// ```not_rust
183/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10
184/// ```
185///
186/// ```not_rust
187/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write
188/// ```
189///
190/// ```not_rust
191/// postgresql:///mydb?user=user&host=/var/lib/postgresql
192/// ```
193#[derive(Clone, PartialEq, Eq)]
194pub struct Config {
195    pub(crate) user: Option<String>,
196    pub(crate) password: Option<Vec<u8>>,
197    pub(crate) dbname: Option<String>,
198    pub(crate) options: Option<String>,
199    pub(crate) application_name: Option<String>,
200    pub(crate) ssl_mode: SslMode,
201    pub(crate) host: Vec<Host>,
202    pub(crate) hostaddr: Vec<IpAddr>,
203    pub(crate) port: Vec<u16>,
204    pub(crate) connect_timeout: Option<Duration>,
205    pub(crate) tcp_user_timeout: Option<Duration>,
206    pub(crate) keepalives: bool,
207    #[cfg(not(target_arch = "wasm32"))]
208    pub(crate) keepalive_config: KeepaliveConfig,
209    pub(crate) target_session_attrs: TargetSessionAttrs,
210    pub(crate) channel_binding: ChannelBinding,
211    pub(crate) load_balance_hosts: LoadBalanceHosts,
212}
213
214impl Default for Config {
215    fn default() -> Config {
216        Config::new()
217    }
218}
219
220impl Config {
221    /// Creates a new configuration.
222    pub fn new() -> Config {
223        Config {
224            user: None,
225            password: None,
226            dbname: None,
227            options: None,
228            application_name: None,
229            ssl_mode: SslMode::Prefer,
230            host: vec![],
231            hostaddr: vec![],
232            port: vec![],
233            connect_timeout: None,
234            tcp_user_timeout: None,
235            keepalives: true,
236            #[cfg(not(target_arch = "wasm32"))]
237            keepalive_config: KeepaliveConfig {
238                idle: Duration::from_secs(2 * 60 * 60),
239                interval: None,
240                retries: None,
241            },
242            target_session_attrs: TargetSessionAttrs::Any,
243            channel_binding: ChannelBinding::Prefer,
244            load_balance_hosts: LoadBalanceHosts::Disable,
245        }
246    }
247
248    /// Sets the user to authenticate with.
249    ///
250    /// Defaults to the user executing this process.
251    pub fn user(&mut self, user: impl Into<String>) -> &mut Config {
252        self.user = Some(user.into());
253        self
254    }
255
256    /// Gets the user to authenticate with, if one has been configured with
257    /// the `user` method.
258    pub fn get_user(&self) -> Option<&str> {
259        self.user.as_deref()
260    }
261
262    /// Sets the password to authenticate with.
263    pub fn password<T>(&mut self, password: T) -> &mut Config
264    where
265        T: AsRef<[u8]>,
266    {
267        self.password = Some(password.as_ref().to_vec());
268        self
269    }
270
271    /// Gets the password to authenticate with, if one has been configured with
272    /// the `password` method.
273    pub fn get_password(&self) -> Option<&[u8]> {
274        self.password.as_deref()
275    }
276
277    /// Sets the name of the database to connect to.
278    ///
279    /// Defaults to the user.
280    pub fn dbname(&mut self, dbname: impl Into<String>) -> &mut Config {
281        self.dbname = Some(dbname.into());
282        self
283    }
284
285    /// Gets the name of the database to connect to, if one has been configured
286    /// with the `dbname` method.
287    pub fn get_dbname(&self) -> Option<&str> {
288        self.dbname.as_deref()
289    }
290
291    /// Sets command line options used to configure the server.
292    pub fn options(&mut self, options: impl Into<String>) -> &mut Config {
293        self.options = Some(options.into());
294        self
295    }
296
297    /// Gets the command line options used to configure the server, if the
298    /// options have been set with the `options` method.
299    pub fn get_options(&self) -> Option<&str> {
300        self.options.as_deref()
301    }
302
303    /// Sets the value of the `application_name` runtime parameter.
304    pub fn application_name(&mut self, application_name: impl Into<String>) -> &mut Config {
305        self.application_name = Some(application_name.into());
306        self
307    }
308
309    /// Gets the value of the `application_name` runtime parameter, if it has
310    /// been set with the `application_name` method.
311    pub fn get_application_name(&self) -> Option<&str> {
312        self.application_name.as_deref()
313    }
314
315    /// Sets the SSL configuration.
316    ///
317    /// Defaults to `prefer`.
318    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
319        self.ssl_mode = ssl_mode;
320        self
321    }
322
323    /// Gets the SSL configuration.
324    pub fn get_ssl_mode(&self) -> SslMode {
325        self.ssl_mode
326    }
327
328    /// Adds a host to the configuration.
329    ///
330    /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
331    /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
332    /// There must be either no hosts, or the same number of hosts as hostaddrs.
333    pub fn host(&mut self, host: impl Into<String>) -> &mut Config {
334        let host = host.into();
335
336        #[cfg(unix)]
337        {
338            if host.starts_with('/') {
339                return self.host_path(host);
340            }
341        }
342
343        self.host.push(Host::Tcp(host));
344        self
345    }
346
347    /// Gets the hosts that have been added to the configuration with `host`.
348    pub fn get_hosts(&self) -> &[Host] {
349        &self.host
350    }
351
352    /// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
353    pub fn get_hostaddrs(&self) -> &[IpAddr] {
354        self.hostaddr.deref()
355    }
356
357    /// Adds a Unix socket host to the configuration.
358    ///
359    /// Unlike `host`, this method allows non-UTF8 paths.
360    #[cfg(unix)]
361    pub fn host_path<T>(&mut self, host: T) -> &mut Config
362    where
363        T: AsRef<Path>,
364    {
365        self.host.push(Host::Unix(host.as_ref().to_path_buf()));
366        self
367    }
368
369    /// Adds a hostaddr to the configuration.
370    ///
371    /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
372    /// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
373    pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
374        self.hostaddr.push(hostaddr);
375        self
376    }
377
378    /// Adds a port to the configuration.
379    ///
380    /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
381    /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
382    /// as hosts.
383    pub fn port(&mut self, port: u16) -> &mut Config {
384        self.port.push(port);
385        self
386    }
387
388    /// Gets the ports that have been added to the configuration with `port`.
389    pub fn get_ports(&self) -> &[u16] {
390        &self.port
391    }
392
393    /// Sets the timeout applied to socket-level connection attempts.
394    ///
395    /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
396    /// host separately. Defaults to no limit.
397    pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
398        self.connect_timeout = Some(connect_timeout);
399        self
400    }
401
402    /// Gets the connection timeout, if one has been set with the
403    /// `connect_timeout` method.
404    pub fn get_connect_timeout(&self) -> Option<&Duration> {
405        self.connect_timeout.as_ref()
406    }
407
408    /// Sets the TCP user timeout.
409    ///
410    /// This is ignored for Unix domain socket connections. It is only supported on systems where
411    /// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
412    /// on other systems, it has no effect.
413    pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
414        self.tcp_user_timeout = Some(tcp_user_timeout);
415        self
416    }
417
418    /// Gets the TCP user timeout, if one has been set with the
419    /// `user_timeout` method.
420    pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
421        self.tcp_user_timeout.as_ref()
422    }
423
424    /// Controls the use of TCP keepalive.
425    ///
426    /// This is ignored for Unix domain socket connections. Defaults to `true`.
427    pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
428        self.keepalives = keepalives;
429        self
430    }
431
432    /// Reports whether TCP keepalives will be used.
433    pub fn get_keepalives(&self) -> bool {
434        self.keepalives
435    }
436
437    /// Sets the amount of idle time before a keepalive packet is sent on the connection.
438    ///
439    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
440    #[cfg(not(target_arch = "wasm32"))]
441    pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
442        self.keepalive_config.idle = keepalives_idle;
443        self
444    }
445
446    /// Gets the configured amount of idle time before a keepalive packet will
447    /// be sent on the connection.
448    #[cfg(not(target_arch = "wasm32"))]
449    pub fn get_keepalives_idle(&self) -> Duration {
450        self.keepalive_config.idle
451    }
452
453    /// Sets the time interval between TCP keepalive probes.
454    /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field.
455    ///
456    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
457    #[cfg(not(target_arch = "wasm32"))]
458    pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
459        self.keepalive_config.interval = Some(keepalives_interval);
460        self
461    }
462
463    /// Gets the time interval between TCP keepalive probes.
464    #[cfg(not(target_arch = "wasm32"))]
465    pub fn get_keepalives_interval(&self) -> Option<Duration> {
466        self.keepalive_config.interval
467    }
468
469    /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
470    ///
471    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
472    #[cfg(not(target_arch = "wasm32"))]
473    pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
474        self.keepalive_config.retries = Some(keepalives_retries);
475        self
476    }
477
478    /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
479    #[cfg(not(target_arch = "wasm32"))]
480    pub fn get_keepalives_retries(&self) -> Option<u32> {
481        self.keepalive_config.retries
482    }
483
484    /// Sets the requirements of the session.
485    ///
486    /// This can be used to connect to the primary server in a clustered database rather than one of the read-only
487    /// secondary servers. Defaults to `Any`.
488    pub fn target_session_attrs(
489        &mut self,
490        target_session_attrs: TargetSessionAttrs,
491    ) -> &mut Config {
492        self.target_session_attrs = target_session_attrs;
493        self
494    }
495
496    /// Gets the requirements of the session.
497    pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
498        self.target_session_attrs
499    }
500
501    /// Sets the channel binding behavior.
502    ///
503    /// Defaults to `prefer`.
504    pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
505        self.channel_binding = channel_binding;
506        self
507    }
508
509    /// Gets the channel binding behavior.
510    pub fn get_channel_binding(&self) -> ChannelBinding {
511        self.channel_binding
512    }
513
514    /// Sets the host load balancing behavior.
515    ///
516    /// Defaults to `disable`.
517    pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
518        self.load_balance_hosts = load_balance_hosts;
519        self
520    }
521
522    /// Gets the host load balancing behavior.
523    pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
524        self.load_balance_hosts
525    }
526
527    fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
528        match key {
529            "user" => {
530                self.user(value);
531            }
532            "password" => {
533                self.password(value);
534            }
535            "dbname" => {
536                self.dbname(value);
537            }
538            "options" => {
539                self.options(value);
540            }
541            "application_name" => {
542                self.application_name(value);
543            }
544            "sslmode" => {
545                let mode = match value {
546                    "disable" => SslMode::Disable,
547                    "prefer" => SslMode::Prefer,
548                    "require" => SslMode::Require,
549                    _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
550                };
551                self.ssl_mode(mode);
552            }
553            "host" => {
554                for host in value.split(',') {
555                    self.host(host);
556                }
557            }
558            "hostaddr" => {
559                for hostaddr in value.split(',') {
560                    let addr = hostaddr
561                        .parse()
562                        .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?;
563                    self.hostaddr(addr);
564                }
565            }
566            "port" => {
567                for port in value.split(',') {
568                    let port = if port.is_empty() {
569                        5432
570                    } else {
571                        port.parse()
572                            .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
573                    };
574                    self.port(port);
575                }
576            }
577            "connect_timeout" => {
578                let timeout = value
579                    .parse::<i64>()
580                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
581                if timeout > 0 {
582                    self.connect_timeout(Duration::from_secs(timeout as u64));
583                }
584            }
585            "tcp_user_timeout" => {
586                let timeout = value
587                    .parse::<i64>()
588                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("tcp_user_timeout"))))?;
589                if timeout > 0 {
590                    self.tcp_user_timeout(Duration::from_secs(timeout as u64));
591                }
592            }
593            #[cfg(not(target_arch = "wasm32"))]
594            "keepalives" => {
595                let keepalives = value
596                    .parse::<u64>()
597                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
598                self.keepalives(keepalives != 0);
599            }
600            #[cfg(not(target_arch = "wasm32"))]
601            "keepalives_idle" => {
602                let keepalives_idle = value
603                    .parse::<i64>()
604                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
605                if keepalives_idle > 0 {
606                    self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
607                }
608            }
609            #[cfg(not(target_arch = "wasm32"))]
610            "keepalives_interval" => {
611                let keepalives_interval = value.parse::<i64>().map_err(|_| {
612                    Error::config_parse(Box::new(InvalidValue("keepalives_interval")))
613                })?;
614                if keepalives_interval > 0 {
615                    self.keepalives_interval(Duration::from_secs(keepalives_interval as u64));
616                }
617            }
618            #[cfg(not(target_arch = "wasm32"))]
619            "keepalives_retries" => {
620                let keepalives_retries = value.parse::<u32>().map_err(|_| {
621                    Error::config_parse(Box::new(InvalidValue("keepalives_retries")))
622                })?;
623                self.keepalives_retries(keepalives_retries);
624            }
625            "target_session_attrs" => {
626                let target_session_attrs = match value {
627                    "any" => TargetSessionAttrs::Any,
628                    "read-write" => TargetSessionAttrs::ReadWrite,
629                    "read-only" => TargetSessionAttrs::ReadOnly,
630                    _ => {
631                        return Err(Error::config_parse(Box::new(InvalidValue(
632                            "target_session_attrs",
633                        ))));
634                    }
635                };
636                self.target_session_attrs(target_session_attrs);
637            }
638            "channel_binding" => {
639                let channel_binding = match value {
640                    "disable" => ChannelBinding::Disable,
641                    "prefer" => ChannelBinding::Prefer,
642                    "require" => ChannelBinding::Require,
643                    _ => {
644                        return Err(Error::config_parse(Box::new(InvalidValue(
645                            "channel_binding",
646                        ))))
647                    }
648                };
649                self.channel_binding(channel_binding);
650            }
651            "load_balance_hosts" => {
652                let load_balance_hosts = match value {
653                    "disable" => LoadBalanceHosts::Disable,
654                    "random" => LoadBalanceHosts::Random,
655                    _ => {
656                        return Err(Error::config_parse(Box::new(InvalidValue(
657                            "load_balance_hosts",
658                        ))))
659                    }
660                };
661                self.load_balance_hosts(load_balance_hosts);
662            }
663            key => {
664                return Err(Error::config_parse(Box::new(UnknownOption(
665                    key.to_string(),
666                ))));
667            }
668        }
669
670        Ok(())
671    }
672
673    /// Opens a connection to a PostgreSQL database.
674    ///
675    /// Requires the `runtime` Cargo feature (enabled by default).
676    #[cfg(feature = "runtime")]
677    pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
678    where
679        T: MakeTlsConnect<Socket>,
680    {
681        connect(tls, self).await
682    }
683
684    /// Connects to a PostgreSQL database over an arbitrary stream.
685    ///
686    /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored.
687    pub async fn connect_raw<S, T>(
688        &self,
689        stream: S,
690        tls: T,
691    ) -> Result<(Client, Connection<S, T::Stream>), Error>
692    where
693        S: AsyncRead + AsyncWrite + Unpin,
694        T: TlsConnect<S>,
695    {
696        connect_raw(stream, tls, true, self).await
697    }
698}
699
700impl FromStr for Config {
701    type Err = Error;
702
703    fn from_str(s: &str) -> Result<Config, Error> {
704        match UrlParser::parse(s)? {
705            Some(config) => Ok(config),
706            None => Parser::parse(s),
707        }
708    }
709}
710
711// Omit password from debug output
712impl fmt::Debug for Config {
713    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
714        struct Redaction {}
715        impl fmt::Debug for Redaction {
716            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
717                write!(f, "_")
718            }
719        }
720
721        let mut config_dbg = &mut f.debug_struct("Config");
722        config_dbg = config_dbg
723            .field("user", &self.user)
724            .field("password", &self.password.as_ref().map(|_| Redaction {}))
725            .field("dbname", &self.dbname)
726            .field("options", &self.options)
727            .field("application_name", &self.application_name)
728            .field("ssl_mode", &self.ssl_mode)
729            .field("host", &self.host)
730            .field("hostaddr", &self.hostaddr)
731            .field("port", &self.port)
732            .field("connect_timeout", &self.connect_timeout)
733            .field("tcp_user_timeout", &self.tcp_user_timeout)
734            .field("keepalives", &self.keepalives);
735
736        #[cfg(not(target_arch = "wasm32"))]
737        {
738            config_dbg = config_dbg
739                .field("keepalives_idle", &self.keepalive_config.idle)
740                .field("keepalives_interval", &self.keepalive_config.interval)
741                .field("keepalives_retries", &self.keepalive_config.retries);
742        }
743
744        config_dbg
745            .field("target_session_attrs", &self.target_session_attrs)
746            .field("channel_binding", &self.channel_binding)
747            .finish()
748    }
749}
750
/// Error reported when a connection string uses a key that is not recognized.
///
/// Holds the unrecognized key.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownOption(name) = self;
        write!(out, "unknown option `{}`", name)
    }
}

impl error::Error for UnknownOption {}
761
/// Error reported when a recognized connection option has a value that cannot be parsed.
///
/// Holds the name of the option whose value was rejected.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = *self;
        write!(out, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
772
/// Cursor over a key-value (libpq-style) connection string.
struct Parser<'a> {
    /// The complete input; tokens are sliced out of it by byte offset.
    s: &'a str,
    /// Remaining `(byte offset, char)` pairs of `s`, with one item of lookahead.
    it: iter::Peekable<str::CharIndices<'a>>,
}
777
778impl<'a> Parser<'a> {
779    fn parse(s: &'a str) -> Result<Config, Error> {
780        let mut parser = Parser {
781            s,
782            it: s.char_indices().peekable(),
783        };
784
785        let mut config = Config::new();
786
787        while let Some((key, value)) = parser.parameter()? {
788            config.param(key, &value)?;
789        }
790
791        Ok(config)
792    }
793
794    fn skip_ws(&mut self) {
795        self.take_while(char::is_whitespace);
796    }
797
798    fn take_while<F>(&mut self, f: F) -> &'a str
799    where
800        F: Fn(char) -> bool,
801    {
802        let start = match self.it.peek() {
803            Some(&(i, _)) => i,
804            None => return "",
805        };
806
807        loop {
808            match self.it.peek() {
809                Some(&(_, c)) if f(c) => {
810                    self.it.next();
811                }
812                Some(&(i, _)) => return &self.s[start..i],
813                None => return &self.s[start..],
814            }
815        }
816    }
817
818    fn eat(&mut self, target: char) -> Result<(), Error> {
819        match self.it.next() {
820            Some((_, c)) if c == target => Ok(()),
821            Some((i, c)) => {
822                let m = format!(
823                    "unexpected character at byte {}: expected `{}` but got `{}`",
824                    i, target, c
825                );
826                Err(Error::config_parse(m.into()))
827            }
828            None => Err(Error::config_parse("unexpected EOF".into())),
829        }
830    }
831
832    fn eat_if(&mut self, target: char) -> bool {
833        match self.it.peek() {
834            Some(&(_, c)) if c == target => {
835                self.it.next();
836                true
837            }
838            _ => false,
839        }
840    }
841
842    fn keyword(&mut self) -> Option<&'a str> {
843        let s = self.take_while(|c| match c {
844            c if c.is_whitespace() => false,
845            '=' => false,
846            _ => true,
847        });
848
849        if s.is_empty() {
850            None
851        } else {
852            Some(s)
853        }
854    }
855
856    fn value(&mut self) -> Result<String, Error> {
857        let value = if self.eat_if('\'') {
858            let value = self.quoted_value()?;
859            self.eat('\'')?;
860            value
861        } else {
862            self.simple_value()?
863        };
864
865        Ok(value)
866    }
867
868    fn simple_value(&mut self) -> Result<String, Error> {
869        let mut value = String::new();
870
871        while let Some(&(_, c)) = self.it.peek() {
872            if c.is_whitespace() {
873                break;
874            }
875
876            self.it.next();
877            if c == '\\' {
878                if let Some((_, c2)) = self.it.next() {
879                    value.push(c2);
880                }
881            } else {
882                value.push(c);
883            }
884        }
885
886        if value.is_empty() {
887            return Err(Error::config_parse("unexpected EOF".into()));
888        }
889
890        Ok(value)
891    }
892
893    fn quoted_value(&mut self) -> Result<String, Error> {
894        let mut value = String::new();
895
896        while let Some(&(_, c)) = self.it.peek() {
897            if c == '\'' {
898                return Ok(value);
899            }
900
901            self.it.next();
902            if c == '\\' {
903                if let Some((_, c2)) = self.it.next() {
904                    value.push(c2);
905                }
906            } else {
907                value.push(c);
908            }
909        }
910
911        Err(Error::config_parse(
912            "unterminated quoted connection parameter value".into(),
913        ))
914    }
915
916    fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
917        self.skip_ws();
918        let keyword = match self.keyword() {
919            Some(keyword) => keyword,
920            None => return Ok(None),
921        };
922        self.skip_ws();
923        self.eat('=')?;
924        self.skip_ws();
925        let value = self.value()?;
926
927        Ok(Some((keyword, value)))
928    }
929}
930
931// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
932struct UrlParser<'a> {
933    s: &'a str,
934    config: Config,
935}
936
937impl<'a> UrlParser<'a> {
938    fn parse(s: &'a str) -> Result<Option<Config>, Error> {
939        let s = match Self::remove_url_prefix(s) {
940            Some(s) => s,
941            None => return Ok(None),
942        };
943
944        let mut parser = UrlParser {
945            s,
946            config: Config::new(),
947        };
948
949        parser.parse_credentials()?;
950        parser.parse_host()?;
951        parser.parse_path()?;
952        parser.parse_params()?;
953
954        Ok(Some(parser.config))
955    }
956
957    fn remove_url_prefix(s: &str) -> Option<&str> {
958        for prefix in &["postgres://", "postgresql://"] {
959            if let Some(stripped) = s.strip_prefix(prefix) {
960                return Some(stripped);
961            }
962        }
963
964        None
965    }
966
967    fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
968        match self.s.find(end) {
969            Some(pos) => {
970                let (head, tail) = self.s.split_at(pos);
971                self.s = tail;
972                Some(head)
973            }
974            None => None,
975        }
976    }
977
978    fn take_all(&mut self) -> &'a str {
979        mem::take(&mut self.s)
980    }
981
982    fn eat_byte(&mut self) {
983        self.s = &self.s[1..];
984    }
985
986    fn parse_credentials(&mut self) -> Result<(), Error> {
987        let creds = match self.take_until(&['@']) {
988            Some(creds) => creds,
989            None => return Ok(()),
990        };
991        self.eat_byte();
992
993        let mut it = creds.splitn(2, ':');
994        let user = self.decode(it.next().unwrap())?;
995        self.config.user(user);
996
997        if let Some(password) = it.next() {
998            let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
999            self.config.password(password);
1000        }
1001
1002        Ok(())
1003    }
1004
1005    fn parse_host(&mut self) -> Result<(), Error> {
1006        let host = match self.take_until(&['/', '?']) {
1007            Some(host) => host,
1008            None => self.take_all(),
1009        };
1010
1011        if host.is_empty() {
1012            return Ok(());
1013        }
1014
1015        for chunk in host.split(',') {
1016            let (host, port) = if chunk.starts_with('[') {
1017                let idx = match chunk.find(']') {
1018                    Some(idx) => idx,
1019                    None => return Err(Error::config_parse(InvalidValue("host").into())),
1020                };
1021
1022                let host = &chunk[1..idx];
1023                let remaining = &chunk[idx + 1..];
1024                let port = if let Some(port) = remaining.strip_prefix(':') {
1025                    Some(port)
1026                } else if remaining.is_empty() {
1027                    None
1028                } else {
1029                    return Err(Error::config_parse(InvalidValue("host").into()));
1030                };
1031
1032                (host, port)
1033            } else {
1034                let mut it = chunk.splitn(2, ':');
1035                (it.next().unwrap(), it.next())
1036            };
1037
1038            self.host_param(host)?;
1039            let port = self.decode(port.unwrap_or("5432"))?;
1040            self.config.param("port", &port)?;
1041        }
1042
1043        Ok(())
1044    }
1045
1046    fn parse_path(&mut self) -> Result<(), Error> {
1047        if !self.s.starts_with('/') {
1048            return Ok(());
1049        }
1050        self.eat_byte();
1051
1052        let dbname = match self.take_until(&['?']) {
1053            Some(dbname) => dbname,
1054            None => self.take_all(),
1055        };
1056
1057        if !dbname.is_empty() {
1058            self.config.dbname(self.decode(dbname)?);
1059        }
1060
1061        Ok(())
1062    }
1063
1064    fn parse_params(&mut self) -> Result<(), Error> {
1065        if !self.s.starts_with('?') {
1066            return Ok(());
1067        }
1068        self.eat_byte();
1069
1070        while !self.s.is_empty() {
1071            let key = match self.take_until(&['=']) {
1072                Some(key) => self.decode(key)?,
1073                None => return Err(Error::config_parse("unterminated parameter".into())),
1074            };
1075            self.eat_byte();
1076
1077            let value = match self.take_until(&['&']) {
1078                Some(value) => {
1079                    self.eat_byte();
1080                    value
1081                }
1082                None => self.take_all(),
1083            };
1084
1085            if key == "host" {
1086                self.host_param(value)?;
1087            } else {
1088                let value = self.decode(value)?;
1089                self.config.param(&key, &value)?;
1090            }
1091        }
1092
1093        Ok(())
1094    }
1095
1096    #[cfg(unix)]
1097    fn host_param(&mut self, s: &str) -> Result<(), Error> {
1098        let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
1099        if decoded.first() == Some(&b'/') {
1100            self.config.host_path(OsStr::from_bytes(&decoded));
1101        } else {
1102            let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
1103            self.config.host(decoded);
1104        }
1105
1106        Ok(())
1107    }
1108
1109    #[cfg(not(unix))]
1110    fn host_param(&mut self, s: &str) -> Result<(), Error> {
1111        let s = self.decode(s)?;
1112        self.config.param("host", &s)
1113    }
1114
1115    fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
1116        percent_encoding::percent_decode(s.as_bytes())
1117            .decode_utf8()
1118            .map_err(|e| Error::config_parse(e.into()))
1119    }
1120}
1121
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use crate::{config::Host, Config};

    /// Keyword/value strings with multiple hosts and host addresses.
    #[test]
    fn test_simple_parsing() {
        let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );

        assert_eq!(
            [
                "127.0.0.1".parse::<IpAddr>().unwrap(),
                "127.0.0.2".parse::<IpAddr>().unwrap()
            ],
            config.get_hostaddrs(),
        );
    }

    /// A malformed `hostaddr` must be rejected rather than ignored.
    #[test]
    fn test_invalid_hostaddr_parsing() {
        let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257";
        assert!(s.parse::<Config>().is_err());
    }

    /// Single-quoted values may contain spaces; a backslash escapes the next character in
    /// bare values.
    #[test]
    fn test_quoted_and_escaped_values() {
        let s = r"user='pass user' dbname=my\ db";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass user"), config.get_user());
        assert_eq!(Some("my db"), config.get_dbname());
    }

    /// URL form: credentials, bracketed IPv6 host, default port and percent-decoded path.
    #[test]
    fn test_url_parsing() {
        let s = "postgresql://pass_user@host1,[::1]:5433/my%20db";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("my db"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("::1".to_string())
            ],
            config.get_hosts(),
        );
        assert_eq!([5432u16, 5433], config.get_ports());
    }
}