1#[cfg(feature = "runtime")]
4use crate::connect::connect;
5use crate::connect_raw::connect_raw;
6#[cfg(not(target_arch = "wasm32"))]
7use crate::keepalive::KeepaliveConfig;
8#[cfg(feature = "runtime")]
9use crate::tls::MakeTlsConnect;
10use crate::tls::TlsConnect;
11#[cfg(feature = "runtime")]
12use crate::Socket;
13use crate::{Client, Connection, Error};
14use std::borrow::Cow;
15#[cfg(unix)]
16use std::ffi::OsStr;
17use std::net::IpAddr;
18use std::ops::Deref;
19#[cfg(unix)]
20use std::os::unix::ffi::OsStrExt;
21#[cfg(unix)]
22use std::path::{Path, PathBuf};
23use std::str;
24use std::str::FromStr;
25use std::time::Duration;
26use std::{error, fmt, iter, mem};
27use tokio::io::{AsyncRead, AsyncWrite};
28
/// Properties a session must have, chosen with the `target_session_attrs`
/// connection option.
///
/// [`Config::new`] selects `Any`.
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TargetSessionAttrs {
    /// Connection-string value `any`.
    Any,
    /// Connection-string value `read-write`.
    ReadWrite,
    /// Connection-string value `read-only`.
    ReadOnly,
}
40
/// TLS behavior, chosen with the `sslmode` connection option.
///
/// [`Config::new`] selects `Prefer`.
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SslMode {
    /// Connection-string value `disable`.
    Disable,
    /// Connection-string value `prefer`.
    Prefer,
    /// Connection-string value `require`.
    Require,
}
52
/// Channel binding behavior, chosen with the `channel_binding` connection
/// option.
///
/// [`Config::new`] selects `Prefer`.
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ChannelBinding {
    /// Connection-string value `disable`.
    Disable,
    /// Connection-string value `prefer`.
    Prefer,
    /// Connection-string value `require`.
    Require,
}
64
/// Host load-balancing behavior, chosen with the `load_balance_hosts`
/// connection option.
///
/// [`Config::new`] selects `Disable`.
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum LoadBalanceHosts {
    /// Connection-string value `disable`.
    Disable,
    /// Connection-string value `random`.
    Random,
}
74
/// A single configured host.
///
/// On Unix, a host string beginning with `/` is stored as [`Host::Unix`]. Any
/// other host string is stored as [`Host::Tcp`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A host name or address, stored verbatim.
    Tcp(String),
    /// A filesystem path (only available on Unix targets).
    #[cfg(unix)]
    Unix(PathBuf),
}
86
87#[derive(Clone, PartialEq, Eq)]
194pub struct Config {
195 pub(crate) user: Option<String>,
196 pub(crate) password: Option<Vec<u8>>,
197 pub(crate) dbname: Option<String>,
198 pub(crate) options: Option<String>,
199 pub(crate) application_name: Option<String>,
200 pub(crate) ssl_mode: SslMode,
201 pub(crate) host: Vec<Host>,
202 pub(crate) hostaddr: Vec<IpAddr>,
203 pub(crate) port: Vec<u16>,
204 pub(crate) connect_timeout: Option<Duration>,
205 pub(crate) tcp_user_timeout: Option<Duration>,
206 pub(crate) keepalives: bool,
207 #[cfg(not(target_arch = "wasm32"))]
208 pub(crate) keepalive_config: KeepaliveConfig,
209 pub(crate) target_session_attrs: TargetSessionAttrs,
210 pub(crate) channel_binding: ChannelBinding,
211 pub(crate) load_balance_hosts: LoadBalanceHosts,
212}
213
214impl Default for Config {
215 fn default() -> Config {
216 Config::new()
217 }
218}
219
220impl Config {
221 pub fn new() -> Config {
223 Config {
224 user: None,
225 password: None,
226 dbname: None,
227 options: None,
228 application_name: None,
229 ssl_mode: SslMode::Prefer,
230 host: vec![],
231 hostaddr: vec![],
232 port: vec![],
233 connect_timeout: None,
234 tcp_user_timeout: None,
235 keepalives: true,
236 #[cfg(not(target_arch = "wasm32"))]
237 keepalive_config: KeepaliveConfig {
238 idle: Duration::from_secs(2 * 60 * 60),
239 interval: None,
240 retries: None,
241 },
242 target_session_attrs: TargetSessionAttrs::Any,
243 channel_binding: ChannelBinding::Prefer,
244 load_balance_hosts: LoadBalanceHosts::Disable,
245 }
246 }
247
248 pub fn user(&mut self, user: impl Into<String>) -> &mut Config {
252 self.user = Some(user.into());
253 self
254 }
255
256 pub fn get_user(&self) -> Option<&str> {
259 self.user.as_deref()
260 }
261
262 pub fn password<T>(&mut self, password: T) -> &mut Config
264 where
265 T: AsRef<[u8]>,
266 {
267 self.password = Some(password.as_ref().to_vec());
268 self
269 }
270
271 pub fn get_password(&self) -> Option<&[u8]> {
274 self.password.as_deref()
275 }
276
277 pub fn dbname(&mut self, dbname: impl Into<String>) -> &mut Config {
281 self.dbname = Some(dbname.into());
282 self
283 }
284
285 pub fn get_dbname(&self) -> Option<&str> {
288 self.dbname.as_deref()
289 }
290
291 pub fn options(&mut self, options: impl Into<String>) -> &mut Config {
293 self.options = Some(options.into());
294 self
295 }
296
297 pub fn get_options(&self) -> Option<&str> {
300 self.options.as_deref()
301 }
302
303 pub fn application_name(&mut self, application_name: impl Into<String>) -> &mut Config {
305 self.application_name = Some(application_name.into());
306 self
307 }
308
309 pub fn get_application_name(&self) -> Option<&str> {
312 self.application_name.as_deref()
313 }
314
315 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
319 self.ssl_mode = ssl_mode;
320 self
321 }
322
323 pub fn get_ssl_mode(&self) -> SslMode {
325 self.ssl_mode
326 }
327
328 pub fn host(&mut self, host: impl Into<String>) -> &mut Config {
334 let host = host.into();
335
336 #[cfg(unix)]
337 {
338 if host.starts_with('/') {
339 return self.host_path(host);
340 }
341 }
342
343 self.host.push(Host::Tcp(host));
344 self
345 }
346
347 pub fn get_hosts(&self) -> &[Host] {
349 &self.host
350 }
351
352 pub fn get_hostaddrs(&self) -> &[IpAddr] {
354 self.hostaddr.deref()
355 }
356
357 #[cfg(unix)]
361 pub fn host_path<T>(&mut self, host: T) -> &mut Config
362 where
363 T: AsRef<Path>,
364 {
365 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
366 self
367 }
368
369 pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
374 self.hostaddr.push(hostaddr);
375 self
376 }
377
378 pub fn port(&mut self, port: u16) -> &mut Config {
384 self.port.push(port);
385 self
386 }
387
388 pub fn get_ports(&self) -> &[u16] {
390 &self.port
391 }
392
393 pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
398 self.connect_timeout = Some(connect_timeout);
399 self
400 }
401
402 pub fn get_connect_timeout(&self) -> Option<&Duration> {
405 self.connect_timeout.as_ref()
406 }
407
408 pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
414 self.tcp_user_timeout = Some(tcp_user_timeout);
415 self
416 }
417
418 pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
421 self.tcp_user_timeout.as_ref()
422 }
423
424 pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
428 self.keepalives = keepalives;
429 self
430 }
431
432 pub fn get_keepalives(&self) -> bool {
434 self.keepalives
435 }
436
437 #[cfg(not(target_arch = "wasm32"))]
441 pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
442 self.keepalive_config.idle = keepalives_idle;
443 self
444 }
445
446 #[cfg(not(target_arch = "wasm32"))]
449 pub fn get_keepalives_idle(&self) -> Duration {
450 self.keepalive_config.idle
451 }
452
453 #[cfg(not(target_arch = "wasm32"))]
458 pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
459 self.keepalive_config.interval = Some(keepalives_interval);
460 self
461 }
462
463 #[cfg(not(target_arch = "wasm32"))]
465 pub fn get_keepalives_interval(&self) -> Option<Duration> {
466 self.keepalive_config.interval
467 }
468
469 #[cfg(not(target_arch = "wasm32"))]
473 pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
474 self.keepalive_config.retries = Some(keepalives_retries);
475 self
476 }
477
478 #[cfg(not(target_arch = "wasm32"))]
480 pub fn get_keepalives_retries(&self) -> Option<u32> {
481 self.keepalive_config.retries
482 }
483
484 pub fn target_session_attrs(
489 &mut self,
490 target_session_attrs: TargetSessionAttrs,
491 ) -> &mut Config {
492 self.target_session_attrs = target_session_attrs;
493 self
494 }
495
496 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
498 self.target_session_attrs
499 }
500
501 pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
505 self.channel_binding = channel_binding;
506 self
507 }
508
509 pub fn get_channel_binding(&self) -> ChannelBinding {
511 self.channel_binding
512 }
513
514 pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
518 self.load_balance_hosts = load_balance_hosts;
519 self
520 }
521
522 pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
524 self.load_balance_hosts
525 }
526
527 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
528 match key {
529 "user" => {
530 self.user(value);
531 }
532 "password" => {
533 self.password(value);
534 }
535 "dbname" => {
536 self.dbname(value);
537 }
538 "options" => {
539 self.options(value);
540 }
541 "application_name" => {
542 self.application_name(value);
543 }
544 "sslmode" => {
545 let mode = match value {
546 "disable" => SslMode::Disable,
547 "prefer" => SslMode::Prefer,
548 "require" => SslMode::Require,
549 _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
550 };
551 self.ssl_mode(mode);
552 }
553 "host" => {
554 for host in value.split(',') {
555 self.host(host);
556 }
557 }
558 "hostaddr" => {
559 for hostaddr in value.split(',') {
560 let addr = hostaddr
561 .parse()
562 .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?;
563 self.hostaddr(addr);
564 }
565 }
566 "port" => {
567 for port in value.split(',') {
568 let port = if port.is_empty() {
569 5432
570 } else {
571 port.parse()
572 .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
573 };
574 self.port(port);
575 }
576 }
577 "connect_timeout" => {
578 let timeout = value
579 .parse::<i64>()
580 .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
581 if timeout > 0 {
582 self.connect_timeout(Duration::from_secs(timeout as u64));
583 }
584 }
585 "tcp_user_timeout" => {
586 let timeout = value
587 .parse::<i64>()
588 .map_err(|_| Error::config_parse(Box::new(InvalidValue("tcp_user_timeout"))))?;
589 if timeout > 0 {
590 self.tcp_user_timeout(Duration::from_secs(timeout as u64));
591 }
592 }
593 #[cfg(not(target_arch = "wasm32"))]
594 "keepalives" => {
595 let keepalives = value
596 .parse::<u64>()
597 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
598 self.keepalives(keepalives != 0);
599 }
600 #[cfg(not(target_arch = "wasm32"))]
601 "keepalives_idle" => {
602 let keepalives_idle = value
603 .parse::<i64>()
604 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
605 if keepalives_idle > 0 {
606 self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
607 }
608 }
609 #[cfg(not(target_arch = "wasm32"))]
610 "keepalives_interval" => {
611 let keepalives_interval = value.parse::<i64>().map_err(|_| {
612 Error::config_parse(Box::new(InvalidValue("keepalives_interval")))
613 })?;
614 if keepalives_interval > 0 {
615 self.keepalives_interval(Duration::from_secs(keepalives_interval as u64));
616 }
617 }
618 #[cfg(not(target_arch = "wasm32"))]
619 "keepalives_retries" => {
620 let keepalives_retries = value.parse::<u32>().map_err(|_| {
621 Error::config_parse(Box::new(InvalidValue("keepalives_retries")))
622 })?;
623 self.keepalives_retries(keepalives_retries);
624 }
625 "target_session_attrs" => {
626 let target_session_attrs = match value {
627 "any" => TargetSessionAttrs::Any,
628 "read-write" => TargetSessionAttrs::ReadWrite,
629 "read-only" => TargetSessionAttrs::ReadOnly,
630 _ => {
631 return Err(Error::config_parse(Box::new(InvalidValue(
632 "target_session_attrs",
633 ))));
634 }
635 };
636 self.target_session_attrs(target_session_attrs);
637 }
638 "channel_binding" => {
639 let channel_binding = match value {
640 "disable" => ChannelBinding::Disable,
641 "prefer" => ChannelBinding::Prefer,
642 "require" => ChannelBinding::Require,
643 _ => {
644 return Err(Error::config_parse(Box::new(InvalidValue(
645 "channel_binding",
646 ))))
647 }
648 };
649 self.channel_binding(channel_binding);
650 }
651 "load_balance_hosts" => {
652 let load_balance_hosts = match value {
653 "disable" => LoadBalanceHosts::Disable,
654 "random" => LoadBalanceHosts::Random,
655 _ => {
656 return Err(Error::config_parse(Box::new(InvalidValue(
657 "load_balance_hosts",
658 ))))
659 }
660 };
661 self.load_balance_hosts(load_balance_hosts);
662 }
663 key => {
664 return Err(Error::config_parse(Box::new(UnknownOption(
665 key.to_string(),
666 ))));
667 }
668 }
669
670 Ok(())
671 }
672
673 #[cfg(feature = "runtime")]
677 pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
678 where
679 T: MakeTlsConnect<Socket>,
680 {
681 connect(tls, self).await
682 }
683
684 pub async fn connect_raw<S, T>(
688 &self,
689 stream: S,
690 tls: T,
691 ) -> Result<(Client, Connection<S, T::Stream>), Error>
692 where
693 S: AsyncRead + AsyncWrite + Unpin,
694 T: TlsConnect<S>,
695 {
696 connect_raw(stream, tls, true, self).await
697 }
698}
699
700impl FromStr for Config {
701 type Err = Error;
702
703 fn from_str(s: &str) -> Result<Config, Error> {
704 match UrlParser::parse(s)? {
705 Some(config) => Ok(config),
706 None => Parser::parse(s),
707 }
708 }
709}
710
711impl fmt::Debug for Config {
713 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
714 struct Redaction {}
715 impl fmt::Debug for Redaction {
716 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
717 write!(f, "_")
718 }
719 }
720
721 let mut config_dbg = &mut f.debug_struct("Config");
722 config_dbg = config_dbg
723 .field("user", &self.user)
724 .field("password", &self.password.as_ref().map(|_| Redaction {}))
725 .field("dbname", &self.dbname)
726 .field("options", &self.options)
727 .field("application_name", &self.application_name)
728 .field("ssl_mode", &self.ssl_mode)
729 .field("host", &self.host)
730 .field("hostaddr", &self.hostaddr)
731 .field("port", &self.port)
732 .field("connect_timeout", &self.connect_timeout)
733 .field("tcp_user_timeout", &self.tcp_user_timeout)
734 .field("keepalives", &self.keepalives);
735
736 #[cfg(not(target_arch = "wasm32"))]
737 {
738 config_dbg = config_dbg
739 .field("keepalives_idle", &self.keepalive_config.idle)
740 .field("keepalives_interval", &self.keepalive_config.interval)
741 .field("keepalives_retries", &self.keepalive_config.retries);
742 }
743
744 config_dbg
745 .field("target_session_attrs", &self.target_session_attrs)
746 .field("channel_binding", &self.channel_binding)
747 .finish()
748 }
749}
750
/// Error for a connection option key that is not recognized; holds the key.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownOption(key) = self;
        write!(f, "unknown option `{}`", key)
    }
}

impl error::Error for UnknownOption {}
761
/// Error for an option whose value could not be parsed; holds the option name.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = *self;
        write!(f, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
772
/// Cursor over a whitespace-separated `key=value` connection string.
struct Parser<'a> {
    /// The complete input. Keywords are returned as slices borrowed from it.
    s: &'a str,
    /// Peekable `(byte_offset, char)` stream over `s`.
    it: iter::Peekable<str::CharIndices<'a>>,
}
777
778impl<'a> Parser<'a> {
779 fn parse(s: &'a str) -> Result<Config, Error> {
780 let mut parser = Parser {
781 s,
782 it: s.char_indices().peekable(),
783 };
784
785 let mut config = Config::new();
786
787 while let Some((key, value)) = parser.parameter()? {
788 config.param(key, &value)?;
789 }
790
791 Ok(config)
792 }
793
794 fn skip_ws(&mut self) {
795 self.take_while(char::is_whitespace);
796 }
797
798 fn take_while<F>(&mut self, f: F) -> &'a str
799 where
800 F: Fn(char) -> bool,
801 {
802 let start = match self.it.peek() {
803 Some(&(i, _)) => i,
804 None => return "",
805 };
806
807 loop {
808 match self.it.peek() {
809 Some(&(_, c)) if f(c) => {
810 self.it.next();
811 }
812 Some(&(i, _)) => return &self.s[start..i],
813 None => return &self.s[start..],
814 }
815 }
816 }
817
818 fn eat(&mut self, target: char) -> Result<(), Error> {
819 match self.it.next() {
820 Some((_, c)) if c == target => Ok(()),
821 Some((i, c)) => {
822 let m = format!(
823 "unexpected character at byte {}: expected `{}` but got `{}`",
824 i, target, c
825 );
826 Err(Error::config_parse(m.into()))
827 }
828 None => Err(Error::config_parse("unexpected EOF".into())),
829 }
830 }
831
832 fn eat_if(&mut self, target: char) -> bool {
833 match self.it.peek() {
834 Some(&(_, c)) if c == target => {
835 self.it.next();
836 true
837 }
838 _ => false,
839 }
840 }
841
842 fn keyword(&mut self) -> Option<&'a str> {
843 let s = self.take_while(|c| match c {
844 c if c.is_whitespace() => false,
845 '=' => false,
846 _ => true,
847 });
848
849 if s.is_empty() {
850 None
851 } else {
852 Some(s)
853 }
854 }
855
856 fn value(&mut self) -> Result<String, Error> {
857 let value = if self.eat_if('\'') {
858 let value = self.quoted_value()?;
859 self.eat('\'')?;
860 value
861 } else {
862 self.simple_value()?
863 };
864
865 Ok(value)
866 }
867
868 fn simple_value(&mut self) -> Result<String, Error> {
869 let mut value = String::new();
870
871 while let Some(&(_, c)) = self.it.peek() {
872 if c.is_whitespace() {
873 break;
874 }
875
876 self.it.next();
877 if c == '\\' {
878 if let Some((_, c2)) = self.it.next() {
879 value.push(c2);
880 }
881 } else {
882 value.push(c);
883 }
884 }
885
886 if value.is_empty() {
887 return Err(Error::config_parse("unexpected EOF".into()));
888 }
889
890 Ok(value)
891 }
892
893 fn quoted_value(&mut self) -> Result<String, Error> {
894 let mut value = String::new();
895
896 while let Some(&(_, c)) = self.it.peek() {
897 if c == '\'' {
898 return Ok(value);
899 }
900
901 self.it.next();
902 if c == '\\' {
903 if let Some((_, c2)) = self.it.next() {
904 value.push(c2);
905 }
906 } else {
907 value.push(c);
908 }
909 }
910
911 Err(Error::config_parse(
912 "unterminated quoted connection parameter value".into(),
913 ))
914 }
915
916 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
917 self.skip_ws();
918 let keyword = match self.keyword() {
919 Some(keyword) => keyword,
920 None => return Ok(None),
921 };
922 self.skip_ws();
923 self.eat('=')?;
924 self.skip_ws();
925 let value = self.value()?;
926
927 Ok(Some((keyword, value)))
928 }
929}
930
931struct UrlParser<'a> {
933 s: &'a str,
934 config: Config,
935}
936
937impl<'a> UrlParser<'a> {
938 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
939 let s = match Self::remove_url_prefix(s) {
940 Some(s) => s,
941 None => return Ok(None),
942 };
943
944 let mut parser = UrlParser {
945 s,
946 config: Config::new(),
947 };
948
949 parser.parse_credentials()?;
950 parser.parse_host()?;
951 parser.parse_path()?;
952 parser.parse_params()?;
953
954 Ok(Some(parser.config))
955 }
956
957 fn remove_url_prefix(s: &str) -> Option<&str> {
958 for prefix in &["postgres://", "postgresql://"] {
959 if let Some(stripped) = s.strip_prefix(prefix) {
960 return Some(stripped);
961 }
962 }
963
964 None
965 }
966
967 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
968 match self.s.find(end) {
969 Some(pos) => {
970 let (head, tail) = self.s.split_at(pos);
971 self.s = tail;
972 Some(head)
973 }
974 None => None,
975 }
976 }
977
978 fn take_all(&mut self) -> &'a str {
979 mem::take(&mut self.s)
980 }
981
982 fn eat_byte(&mut self) {
983 self.s = &self.s[1..];
984 }
985
986 fn parse_credentials(&mut self) -> Result<(), Error> {
987 let creds = match self.take_until(&['@']) {
988 Some(creds) => creds,
989 None => return Ok(()),
990 };
991 self.eat_byte();
992
993 let mut it = creds.splitn(2, ':');
994 let user = self.decode(it.next().unwrap())?;
995 self.config.user(user);
996
997 if let Some(password) = it.next() {
998 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
999 self.config.password(password);
1000 }
1001
1002 Ok(())
1003 }
1004
1005 fn parse_host(&mut self) -> Result<(), Error> {
1006 let host = match self.take_until(&['/', '?']) {
1007 Some(host) => host,
1008 None => self.take_all(),
1009 };
1010
1011 if host.is_empty() {
1012 return Ok(());
1013 }
1014
1015 for chunk in host.split(',') {
1016 let (host, port) = if chunk.starts_with('[') {
1017 let idx = match chunk.find(']') {
1018 Some(idx) => idx,
1019 None => return Err(Error::config_parse(InvalidValue("host").into())),
1020 };
1021
1022 let host = &chunk[1..idx];
1023 let remaining = &chunk[idx + 1..];
1024 let port = if let Some(port) = remaining.strip_prefix(':') {
1025 Some(port)
1026 } else if remaining.is_empty() {
1027 None
1028 } else {
1029 return Err(Error::config_parse(InvalidValue("host").into()));
1030 };
1031
1032 (host, port)
1033 } else {
1034 let mut it = chunk.splitn(2, ':');
1035 (it.next().unwrap(), it.next())
1036 };
1037
1038 self.host_param(host)?;
1039 let port = self.decode(port.unwrap_or("5432"))?;
1040 self.config.param("port", &port)?;
1041 }
1042
1043 Ok(())
1044 }
1045
1046 fn parse_path(&mut self) -> Result<(), Error> {
1047 if !self.s.starts_with('/') {
1048 return Ok(());
1049 }
1050 self.eat_byte();
1051
1052 let dbname = match self.take_until(&['?']) {
1053 Some(dbname) => dbname,
1054 None => self.take_all(),
1055 };
1056
1057 if !dbname.is_empty() {
1058 self.config.dbname(self.decode(dbname)?);
1059 }
1060
1061 Ok(())
1062 }
1063
1064 fn parse_params(&mut self) -> Result<(), Error> {
1065 if !self.s.starts_with('?') {
1066 return Ok(());
1067 }
1068 self.eat_byte();
1069
1070 while !self.s.is_empty() {
1071 let key = match self.take_until(&['=']) {
1072 Some(key) => self.decode(key)?,
1073 None => return Err(Error::config_parse("unterminated parameter".into())),
1074 };
1075 self.eat_byte();
1076
1077 let value = match self.take_until(&['&']) {
1078 Some(value) => {
1079 self.eat_byte();
1080 value
1081 }
1082 None => self.take_all(),
1083 };
1084
1085 if key == "host" {
1086 self.host_param(value)?;
1087 } else {
1088 let value = self.decode(value)?;
1089 self.config.param(&key, &value)?;
1090 }
1091 }
1092
1093 Ok(())
1094 }
1095
1096 #[cfg(unix)]
1097 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1098 let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
1099 if decoded.first() == Some(&b'/') {
1100 self.config.host_path(OsStr::from_bytes(&decoded));
1101 } else {
1102 let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
1103 self.config.host(decoded);
1104 }
1105
1106 Ok(())
1107 }
1108
1109 #[cfg(not(unix))]
1110 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1111 let s = self.decode(s)?;
1112 self.config.param("host", &s)
1113 }
1114
1115 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
1116 percent_encoding::percent_decode(s.as_bytes())
1117 .decode_utf8()
1118 .map_err(|e| Error::config_parse(e.into()))
1119 }
1120}
1121
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use crate::{config::Host, Config};

    /// Key/value parsing of user, dbname, host list, hostaddr list and port.
    #[test]
    fn test_simple_parsing() {
        let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );

        assert_eq!(
            [
                "127.0.0.1".parse::<IpAddr>().unwrap(),
                "127.0.0.2".parse::<IpAddr>().unwrap()
            ],
            config.get_hostaddrs(),
        );

        // The port was parsed but previously never checked.
        assert_eq!([26257u16], config.get_ports());
    }

    /// A malformed `hostaddr` must be rejected.
    #[test]
    fn test_invalid_hostaddr_parsing() {
        let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257";
        assert!(s.parse::<Config>().is_err());
    }

    /// Quoted values may contain whitespace and backslash-escaped quotes.
    #[test]
    fn test_quoted_value_parsing() {
        let config = r"password='a b\'c' dbname=db".parse::<Config>().unwrap();
        assert_eq!(Some(&b"a b'c"[..]), config.get_password());
        assert_eq!(Some("db"), config.get_dbname());
    }

    /// URL form: credentials, bracketed IPv6 host, default port, path and query.
    #[test]
    fn test_url_parsing() {
        let config = "postgresql://user:pass@host1:1234,[::1]/db?application_name=app"
            .parse::<Config>()
            .unwrap();
        assert_eq!(Some("user"), config.get_user());
        assert_eq!(Some(&b"pass"[..]), config.get_password());
        assert_eq!(Some("db"), config.get_dbname());
        assert_eq!(Some("app"), config.get_application_name());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("::1".to_string())
            ],
            config.get_hosts(),
        );
        assert_eq!([1234u16, 5432], config.get_ports());
    }

    /// Debug output must not reveal the password.
    #[test]
    fn test_debug_redacts_password() {
        let config = "user=u password=hunter2".parse::<Config>().unwrap();
        let rendered = format!("{:?}", config);
        assert!(rendered.contains("password: Some(_)"));
        assert!(!rendered.contains("hunter2"));
    }
}