// reqwest/connect.rs

1#[cfg(feature = "__tls")]
2use http::header::HeaderValue;
3use http::uri::{Authority, Scheme};
4use http::Uri;
5use hyper::client::connect::{Connected, Connection};
6use hyper::service::Service;
7#[cfg(feature = "native-tls-crate")]
8use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11use pin_project_lite::pin_project;
12use std::future::Future;
13use std::io::{self, IoSlice};
14use std::net::IpAddr;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::Duration;
19
20#[cfg(feature = "default-tls")]
21use self::native_tls_conn::NativeTlsConn;
22#[cfg(feature = "__rustls")]
23use self::rustls_tls_conn::RustlsTlsConn;
24use crate::dns::DynResolver;
25use crate::error::BoxError;
26use crate::proxy::{Proxy, ProxyScheme};
27
28pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>;
29
30#[derive(Clone)]
31pub(crate) struct Connector {
32    inner: Inner,
33    proxies: Arc<Vec<Proxy>>,
34    verbose: verbose::Wrapper,
35    timeout: Option<Duration>,
36    #[cfg(feature = "__tls")]
37    nodelay: bool,
38    #[cfg(feature = "__tls")]
39    tls_info: bool,
40    #[cfg(feature = "__tls")]
41    user_agent: Option<HeaderValue>,
42}
43
44#[derive(Clone)]
45enum Inner {
46    #[cfg(not(feature = "__tls"))]
47    Http(HttpConnector),
48    #[cfg(feature = "default-tls")]
49    DefaultTls(HttpConnector, TlsConnector),
50    #[cfg(feature = "__rustls")]
51    RustlsTls {
52        http: HttpConnector,
53        tls: Arc<rustls::ClientConfig>,
54        tls_proxy: Arc<rustls::ClientConfig>,
55    },
56}
57
58impl Connector {
59    #[cfg(not(feature = "__tls"))]
60    pub(crate) fn new<T>(
61        mut http: HttpConnector,
62        proxies: Arc<Vec<Proxy>>,
63        local_addr: T,
64        nodelay: bool,
65    ) -> Connector
66    where
67        T: Into<Option<IpAddr>>,
68    {
69        http.set_local_address(local_addr.into());
70        http.set_nodelay(nodelay);
71
72        Connector {
73            inner: Inner::Http(http),
74            verbose: verbose::OFF,
75            proxies,
76            timeout: None,
77        }
78    }
79
80    #[cfg(feature = "default-tls")]
81    pub(crate) fn new_default_tls<T>(
82        http: HttpConnector,
83        tls: TlsConnectorBuilder,
84        proxies: Arc<Vec<Proxy>>,
85        user_agent: Option<HeaderValue>,
86        local_addr: T,
87        nodelay: bool,
88        tls_info: bool,
89    ) -> crate::Result<Connector>
90    where
91        T: Into<Option<IpAddr>>,
92    {
93        let tls = tls.build().map_err(crate::error::builder)?;
94        Ok(Self::from_built_default_tls(
95            http, tls, proxies, user_agent, local_addr, nodelay, tls_info,
96        ))
97    }
98
99    #[cfg(feature = "default-tls")]
100    pub(crate) fn from_built_default_tls<T>(
101        mut http: HttpConnector,
102        tls: TlsConnector,
103        proxies: Arc<Vec<Proxy>>,
104        user_agent: Option<HeaderValue>,
105        local_addr: T,
106        nodelay: bool,
107        tls_info: bool,
108    ) -> Connector
109    where
110        T: Into<Option<IpAddr>>,
111    {
112        http.set_local_address(local_addr.into());
113        http.set_nodelay(nodelay);
114        http.enforce_http(false);
115
116        Connector {
117            inner: Inner::DefaultTls(http, tls),
118            proxies,
119            verbose: verbose::OFF,
120            timeout: None,
121            nodelay,
122            tls_info,
123            user_agent,
124        }
125    }
126
127    #[cfg(feature = "__rustls")]
128    pub(crate) fn new_rustls_tls<T>(
129        mut http: HttpConnector,
130        tls: rustls::ClientConfig,
131        proxies: Arc<Vec<Proxy>>,
132        user_agent: Option<HeaderValue>,
133        local_addr: T,
134        nodelay: bool,
135        tls_info: bool,
136    ) -> Connector
137    where
138        T: Into<Option<IpAddr>>,
139    {
140        http.set_local_address(local_addr.into());
141        http.set_nodelay(nodelay);
142        http.enforce_http(false);
143
144        let (tls, tls_proxy) = if proxies.is_empty() {
145            let tls = Arc::new(tls);
146            (tls.clone(), tls)
147        } else {
148            let mut tls_proxy = tls.clone();
149            tls_proxy.alpn_protocols.clear();
150            (Arc::new(tls), Arc::new(tls_proxy))
151        };
152
153        Connector {
154            inner: Inner::RustlsTls {
155                http,
156                tls,
157                tls_proxy,
158            },
159            proxies,
160            verbose: verbose::OFF,
161            timeout: None,
162            nodelay,
163            tls_info,
164            user_agent,
165        }
166    }
167
168    pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
169        self.timeout = timeout;
170    }
171
172    pub(crate) fn set_verbose(&mut self, enabled: bool) {
173        self.verbose.0 = enabled;
174    }
175
176    #[cfg(feature = "socks")]
177    async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
178        let dns = match proxy {
179            ProxyScheme::Socks5 {
180                remote_dns: false, ..
181            } => socks::DnsResolve::Local,
182            ProxyScheme::Socks5 {
183                remote_dns: true, ..
184            } => socks::DnsResolve::Proxy,
185            ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
186                unreachable!("connect_socks is only called for socks proxies");
187            }
188        };
189
190        match &self.inner {
191            #[cfg(feature = "default-tls")]
192            Inner::DefaultTls(_http, tls) => {
193                if dst.scheme() == Some(&Scheme::HTTPS) {
194                    let host = dst.host().ok_or("no host in url")?.to_string();
195                    let conn = socks::connect(proxy, dst, dns).await?;
196                    let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
197                    let io = tls_connector.connect(&host, conn).await?;
198                    return Ok(Conn {
199                        inner: self.verbose.wrap(NativeTlsConn { inner: io }),
200                        is_proxy: false,
201                        tls_info: self.tls_info,
202                    });
203                }
204            }
205            #[cfg(feature = "__rustls")]
206            Inner::RustlsTls { tls_proxy, .. } => {
207                if dst.scheme() == Some(&Scheme::HTTPS) {
208                    use std::convert::TryFrom;
209                    use tokio_rustls::TlsConnector as RustlsConnector;
210
211                    let tls = tls_proxy.clone();
212                    let host = dst.host().ok_or("no host in url")?.to_string();
213                    let conn = socks::connect(proxy, dst, dns).await?;
214                    let server_name = rustls::ServerName::try_from(host.as_str())
215                        .map_err(|_| "Invalid Server Name")?;
216                    let io = RustlsConnector::from(tls)
217                        .connect(server_name, conn)
218                        .await?;
219                    return Ok(Conn {
220                        inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
221                        is_proxy: false,
222                        tls_info: false,
223                    });
224                }
225            }
226            #[cfg(not(feature = "__tls"))]
227            Inner::Http(_) => (),
228        }
229
230        socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
231            inner: self.verbose.wrap(tcp),
232            is_proxy: false,
233            tls_info: false,
234        })
235    }
236
237    async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
238        match self.inner {
239            #[cfg(not(feature = "__tls"))]
240            Inner::Http(mut http) => {
241                let io = http.call(dst).await?;
242                Ok(Conn {
243                    inner: self.verbose.wrap(io),
244                    is_proxy,
245                    tls_info: false,
246                })
247            }
248            #[cfg(feature = "default-tls")]
249            Inner::DefaultTls(http, tls) => {
250                let mut http = http.clone();
251
252                // Disable Nagle's algorithm for TLS handshake
253                //
254                // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
255                if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
256                    http.set_nodelay(true);
257                }
258
259                let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
260                let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
261                let io = http.call(dst).await?;
262
263                if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
264                    if !self.nodelay {
265                        stream.get_ref().get_ref().get_ref().set_nodelay(false)?;
266                    }
267                    Ok(Conn {
268                        inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
269                        is_proxy,
270                        tls_info: self.tls_info,
271                    })
272                } else {
273                    Ok(Conn {
274                        inner: self.verbose.wrap(io),
275                        is_proxy,
276                        tls_info: false,
277                    })
278                }
279            }
280            #[cfg(feature = "__rustls")]
281            Inner::RustlsTls { http, tls, .. } => {
282                let mut http = http.clone();
283
284                // Disable Nagle's algorithm for TLS handshake
285                //
286                // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
287                if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
288                    http.set_nodelay(true);
289                }
290
291                let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
292                let io = http.call(dst).await?;
293
294                if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
295                    if !self.nodelay {
296                        let (io, _) = stream.get_ref();
297                        io.set_nodelay(false)?;
298                    }
299                    Ok(Conn {
300                        inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
301                        is_proxy,
302                        tls_info: self.tls_info,
303                    })
304                } else {
305                    Ok(Conn {
306                        inner: self.verbose.wrap(io),
307                        is_proxy,
308                        tls_info: false,
309                    })
310                }
311            }
312        }
313    }
314
315    async fn connect_via_proxy(
316        self,
317        dst: Uri,
318        proxy_scheme: ProxyScheme,
319    ) -> Result<Conn, BoxError> {
320        log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'");
321
322        let (proxy_dst, _auth) = match proxy_scheme {
323            ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
324            ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
325            #[cfg(feature = "socks")]
326            ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
327        };
328
329        #[cfg(feature = "__tls")]
330        let auth = _auth;
331
332        match &self.inner {
333            #[cfg(feature = "default-tls")]
334            Inner::DefaultTls(http, tls) => {
335                if dst.scheme() == Some(&Scheme::HTTPS) {
336                    let host = dst.host().to_owned();
337                    let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
338                    let http = http.clone();
339                    let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
340                    let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
341                    let conn = http.call(proxy_dst).await?;
342                    log::trace!("tunneling HTTPS over proxy");
343                    let tunneled = tunnel(
344                        conn,
345                        host.ok_or("no host in url")?.to_string(),
346                        port,
347                        self.user_agent.clone(),
348                        auth,
349                    )
350                    .await?;
351                    let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
352                    let io = tls_connector
353                        .connect(host.ok_or("no host in url")?, tunneled)
354                        .await?;
355                    return Ok(Conn {
356                        inner: self.verbose.wrap(NativeTlsConn { inner: io }),
357                        is_proxy: false,
358                        tls_info: false,
359                    });
360                }
361            }
362            #[cfg(feature = "__rustls")]
363            Inner::RustlsTls {
364                http,
365                tls,
366                tls_proxy,
367            } => {
368                if dst.scheme() == Some(&Scheme::HTTPS) {
369                    use rustls::ServerName;
370                    use std::convert::TryFrom;
371                    use tokio_rustls::TlsConnector as RustlsConnector;
372
373                    let host = dst.host().ok_or("no host in url")?.to_string();
374                    let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
375                    let http = http.clone();
376                    let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
377                    let tls = tls.clone();
378                    let conn = http.call(proxy_dst).await?;
379                    log::trace!("tunneling HTTPS over proxy");
380                    let maybe_server_name =
381                        ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name");
382                    let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
383                    let server_name = maybe_server_name?;
384                    let io = RustlsConnector::from(tls)
385                        .connect(server_name, tunneled)
386                        .await?;
387
388                    return Ok(Conn {
389                        inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
390                        is_proxy: false,
391                        tls_info: false,
392                    });
393                }
394            }
395            #[cfg(not(feature = "__tls"))]
396            Inner::Http(_) => (),
397        }
398
399        self.connect_with_maybe_proxy(proxy_dst, true).await
400    }
401
402    pub fn set_keepalive(&mut self, dur: Option<Duration>) {
403        match &mut self.inner {
404            #[cfg(feature = "default-tls")]
405            Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
406            #[cfg(feature = "__rustls")]
407            Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
408            #[cfg(not(feature = "__tls"))]
409            Inner::Http(http) => http.set_keepalive(dur),
410        }
411    }
412}
413
414fn into_uri(scheme: Scheme, host: Authority) -> Uri {
415    // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`?
416    http::Uri::builder()
417        .scheme(scheme)
418        .authority(host)
419        .path_and_query(http::uri::PathAndQuery::from_static("/"))
420        .build()
421        .expect("scheme and authority is valid Uri")
422}
423
424async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
425where
426    F: Future<Output = Result<T, BoxError>>,
427{
428    if let Some(to) = timeout {
429        match tokio::time::timeout(to, f).await {
430            Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
431            Ok(Ok(try_res)) => Ok(try_res),
432            Ok(Err(e)) => Err(e),
433        }
434    } else {
435        f.await
436    }
437}
438
439impl Service<Uri> for Connector {
440    type Response = Conn;
441    type Error = BoxError;
442    type Future = Connecting;
443
444    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
445        Poll::Ready(Ok(()))
446    }
447
448    fn call(&mut self, dst: Uri) -> Self::Future {
449        log::debug!("starting new connection: {dst:?}");
450        let timeout = self.timeout;
451        for prox in self.proxies.iter() {
452            if let Some(proxy_scheme) = prox.intercept(&dst) {
453                return Box::pin(with_timeout(
454                    self.clone().connect_via_proxy(dst, proxy_scheme),
455                    timeout,
456                ));
457            }
458        }
459
460        Box::pin(with_timeout(
461            self.clone().connect_with_maybe_proxy(dst, false),
462            timeout,
463        ))
464    }
465}
466
/// Implemented by connection types so that optional TLS session details can
/// be read from them.
#[cfg(feature = "__tls")]
trait TlsInfoFactory {
    /// Returns the TLS details for this connection, or `None` when there are
    /// none to report.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}
471
#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
    /// A bare TCP stream reports no TLS details.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}
478
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<tokio::net::TcpStream> {
    /// Delegates to the TLS stream for `Https`; `Http` yields `None`.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_tls::MaybeHttpsStream::Https(encrypted) = self {
            encrypted.tls_info()
        } else {
            None
        }
    }
}
488
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::TlsStream<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
    /// Always returns `Some`; the certificate inside is `None` when the peer
    /// certificate is unavailable or cannot be DER-encoded.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            Ok(None) | Err(_) => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
501
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<tokio::net::TcpStream> {
    /// Always returns `Some`; the certificate inside is `None` when the peer
    /// certificate is unavailable or cannot be DER-encoded.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            Ok(None) | Err(_) => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
514
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream> {
    /// Delegates to the TLS stream for `Https`; `Http` yields `None`.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(encrypted) = self {
            encrypted.tls_info()
        } else {
            None
        }
    }
}
524
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::TlsStream<tokio::net::TcpStream> {
    /// Always returns `Some`; the certificate inside is a copy of the first
    /// entry of the peer's certificate list, or `None` when there is none.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_io, session) = self.get_ref();
        let first = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first.map(|cert| cert.0.clone()),
        })
    }
}
537
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>>
{
    /// Always returns `Some`; the certificate inside is a copy of the first
    /// entry of the peer's certificate list, or `None` when there is none.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_io, session) = self.get_ref();
        let first = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first.map(|cert| cert.0.clone()),
        })
    }
}
552
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<tokio::net::TcpStream> {
    /// Always returns `Some`; the certificate inside is a copy of the first
    /// entry of the peer's certificate list, or `None` when there is none.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_io, session) = self.get_ref();
        let first = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first.map(|cert| cert.0.clone()),
        })
    }
}
565
566pub(crate) trait AsyncConn:
567    AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static
568{
569}
570
571impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}
572
573#[cfg(feature = "__tls")]
574trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
575#[cfg(not(feature = "__tls"))]
576trait AsyncConnWithInfo: AsyncConn {}
577
578#[cfg(feature = "__tls")]
579impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
580#[cfg(not(feature = "__tls"))]
581impl<T: AsyncConn> AsyncConnWithInfo for T {}
582
583type BoxConn = Box<dyn AsyncConnWithInfo>;
584
585pin_project! {
586    /// Note: the `is_proxy` member means *is plain text HTTP proxy*.
587    /// This tells hyper whether the URI should be written in
588    /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
589    /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
590    pub(crate) struct Conn {
591        #[pin]
592        inner: BoxConn,
593        is_proxy: bool,
594        // Only needed for __tls, but #[cfg()] on fields breaks pin_project!
595        tls_info: bool,
596    }
597}
598
599impl Connection for Conn {
600    fn connected(&self) -> Connected {
601        let connected = self.inner.connected().proxy(self.is_proxy);
602        #[cfg(feature = "__tls")]
603        if self.tls_info {
604            if let Some(tls_info) = self.inner.tls_info() {
605                connected.extra(tls_info)
606            } else {
607                connected
608            }
609        } else {
610            connected
611        }
612        #[cfg(not(feature = "__tls"))]
613        connected
614    }
615}
616
617impl AsyncRead for Conn {
618    fn poll_read(
619        self: Pin<&mut Self>,
620        cx: &mut Context,
621        buf: &mut ReadBuf<'_>,
622    ) -> Poll<io::Result<()>> {
623        let this = self.project();
624        AsyncRead::poll_read(this.inner, cx, buf)
625    }
626}
627
628impl AsyncWrite for Conn {
629    fn poll_write(
630        self: Pin<&mut Self>,
631        cx: &mut Context,
632        buf: &[u8],
633    ) -> Poll<Result<usize, io::Error>> {
634        let this = self.project();
635        AsyncWrite::poll_write(this.inner, cx, buf)
636    }
637
638    fn poll_write_vectored(
639        self: Pin<&mut Self>,
640        cx: &mut Context<'_>,
641        bufs: &[IoSlice<'_>],
642    ) -> Poll<Result<usize, io::Error>> {
643        let this = self.project();
644        AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
645    }
646
647    fn is_write_vectored(&self) -> bool {
648        self.inner.is_write_vectored()
649    }
650
651    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
652        let this = self.project();
653        AsyncWrite::poll_flush(this.inner, cx)
654    }
655
656    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
657        let this = self.project();
658        AsyncWrite::poll_shutdown(this.inner, cx)
659    }
660}
661
662pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
663
/// Establishes an HTTP `CONNECT` tunnel to `host:port` over `conn`.
///
/// Writes a `CONNECT` request (with optional `User-Agent` and
/// `Proxy-Authorization` headers) and reads the proxy's response head into
/// a fixed 8 KiB buffer. On success the same `conn` is handed back.
///
/// # Errors
///
/// * any I/O error from writing the request or reading the response,
/// * "unexpected eof while tunneling" if the proxy closes the connection
///   before a complete response head arrived,
/// * "proxy headers too long for tunnel" if the head does not fit the buffer,
/// * "proxy authentication required" on a `407` status,
/// * "unsuccessful tunnel" for any other response.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // All status-line prefixes compared below share this length.
    const STATUS_PREFIX_LEN: usize = b"HTTP/1.1 200".len();

    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {host}:{port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end
    buf.extend_from_slice(b"\r\n");

    conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        // A read may deliver only part of the status line. Do not classify
        // the response until enough bytes are present to compare a full
        // status prefix; otherwise a short first read would be misreported
        // as a failed tunnel.
        if pos < STATUS_PREFIX_LEN {
            continue;
        }

        let recvd = &buf[..pos];
        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
        // else read more
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
732
/// Error returned when the proxy closes the connection in the middle of the
/// `CONNECT` exchange.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
737
/// Wrapper giving native-tls streams the `Connection`, I/O and
/// `TlsInfoFactory` implementations this file needs.
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_native_tls::TlsStream;

    pin_project! {
        /// A native-tls stream over `T`.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
        /// Metadata of the innermost stream, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        #[cfg(feature = "native-tls-alpn")]
        fn connected(&self) -> Connected {
            let tls = self.inner.get_ref();
            let negotiated_h2 =
                matches!(tls.negotiated_alpn(), Ok(Some(proto)) if proto == b"h2");
            let connected = tls.get_ref().get_ref().connected();
            if negotiated_h2 {
                connected.negotiated_h2()
            } else {
                connected
            }
        }

        /// Metadata of the innermost stream (built without ALPN support).
        #[cfg(not(feature = "native-tls-alpn"))]
        fn connected(&self) -> Connected {
            let tls = self.inner.get_ref();
            tls.get_ref().get_ref().connected()
        }
    }

    // The I/O traits forward to the wrapped TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            AsyncWrite::is_write_vectored(&self.inner)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }

    impl TlsInfoFactory for NativeTlsConn<tokio::net::TcpStream> {
        /// Delegates to the wrapped TLS stream.
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }

    impl TlsInfoFactory for NativeTlsConn<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
        /// Delegates to the wrapped TLS stream.
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }
}
841
/// Wrapper giving rustls client streams the `Connection`, I/O and
/// `TlsInfoFactory` implementations this file needs.
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use super::TlsInfoFactory;
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_rustls::client::TlsStream;

    pin_project! {
        /// A rustls client stream over `T`.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
        /// Metadata of the underlying stream, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        fn connected(&self) -> Connected {
            let (io, session) = self.inner.get_ref();
            let negotiated_h2 = session.alpn_protocol() == Some(b"h2");
            let connected = io.connected();
            if negotiated_h2 {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    // The I/O traits forward to the wrapped TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            AsyncWrite::is_write_vectored(&self.inner)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }

    impl TlsInfoFactory for RustlsTlsConn<tokio::net::TcpStream> {
        /// Delegates to the wrapped TLS stream.
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }

    impl TlsInfoFactory for RustlsTlsConn<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>> {
        /// Delegates to the wrapped TLS stream.
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }
}
934
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::Socks5Stream;

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Selects who resolves the destination host name.
    pub(super) enum DnsResolve {
        /// Resolve the host on this machine and hand the proxy an IP address.
        Local,
        /// Pass the host name through untouched and let the proxy resolve it.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS5 proxy described by `proxy`.
    ///
    /// The target port is taken from `dst`, defaulting to 443 for `https` and
    /// 80 otherwise. With [`DnsResolve::Local`] the host is resolved first and
    /// the first resolved address (if any) is sent to the proxy instead of the
    /// name.
    ///
    /// # Errors
    ///
    /// Fails when `dst` has no host, when local resolution fails, or when the
    /// SOCKS handshake/connection fails (reported as `socks connect error: …`).
    ///
    /// # Panics
    ///
    /// Panics if `proxy` is not a `ProxyScheme::Socks5`.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        // Build the error lazily so nothing is allocated on the success path.
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Resolve without blocking the executor thread: a synchronous
            // `to_socket_addrs()` here would stall every task on this worker
            // for the duration of the DNS lookup.
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port))
                .await?
                .next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let (socket_addr, auth) = match proxy {
            ProxyScheme::Socks5 { addr, auth, .. } => (addr, auth),
            // NOTE(review): presumably callers only route SOCKS5 schemes here;
            // any other scheme panics.
            _ => unreachable!(),
        };

        // Get a Tokio TcpStream
        let target = (host.as_str(), port);
        let attempt = match auth {
            Some((username, password)) => {
                Socks5Stream::connect_with_password(socket_addr, target, &username, &password)
                    .await
            }
            None => Socks5Stream::connect(socket_addr, target).await,
        };
        let stream = attempt.map_err(|e| format!("socks connect error: {e}"))?;

        Ok(stream.into_inner())
    }
}
999
1000mod verbose {
1001    use hyper::client::connect::{Connected, Connection};
1002    use std::cmp::min;
1003    use std::fmt;
1004    use std::io::{self, IoSlice};
1005    use std::pin::Pin;
1006    use std::task::{Context, Poll};
1007    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1008
1009    pub(super) const OFF: Wrapper = Wrapper(false);
1010
1011    #[derive(Clone, Copy)]
1012    pub(super) struct Wrapper(pub(super) bool);
1013
1014    impl Wrapper {
1015        pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn {
1016            if self.0 && log::log_enabled!(log::Level::Trace) {
1017                Box::new(Verbose {
1018                    // truncate is fine
1019                    id: crate::util::fast_random() as u32,
1020                    inner: conn,
1021                })
1022            } else {
1023                Box::new(conn)
1024            }
1025        }
1026    }
1027
1028    struct Verbose<T> {
1029        id: u32,
1030        inner: T,
1031    }
1032
1033    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for Verbose<T> {
1034        fn connected(&self) -> Connected {
1035            self.inner.connected()
1036        }
1037    }
1038
1039    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for Verbose<T> {
1040        fn poll_read(
1041            mut self: Pin<&mut Self>,
1042            cx: &mut Context,
1043            buf: &mut ReadBuf<'_>,
1044        ) -> Poll<std::io::Result<()>> {
1045            match Pin::new(&mut self.inner).poll_read(cx, buf) {
1046                Poll::Ready(Ok(())) => {
1047                    log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled()));
1048                    Poll::Ready(Ok(()))
1049                }
1050                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1051                Poll::Pending => Poll::Pending,
1052            }
1053        }
1054    }
1055
1056    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Verbose<T> {
1057        fn poll_write(
1058            mut self: Pin<&mut Self>,
1059            cx: &mut Context,
1060            buf: &[u8],
1061        ) -> Poll<Result<usize, std::io::Error>> {
1062            match Pin::new(&mut self.inner).poll_write(cx, buf) {
1063                Poll::Ready(Ok(n)) => {
1064                    log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1065                    Poll::Ready(Ok(n))
1066                }
1067                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1068                Poll::Pending => Poll::Pending,
1069            }
1070        }
1071
1072        fn poll_write_vectored(
1073            mut self: Pin<&mut Self>,
1074            cx: &mut Context<'_>,
1075            bufs: &[IoSlice<'_>],
1076        ) -> Poll<Result<usize, io::Error>> {
1077            match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) {
1078                Poll::Ready(Ok(nwritten)) => {
1079                    log::trace!(
1080                        "{:08x} write (vectored): {:?}",
1081                        self.id,
1082                        Vectored { bufs, nwritten }
1083                    );
1084                    Poll::Ready(Ok(nwritten))
1085                }
1086                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1087                Poll::Pending => Poll::Pending,
1088            }
1089        }
1090
1091        fn is_write_vectored(&self) -> bool {
1092            self.inner.is_write_vectored()
1093        }
1094
1095        fn poll_flush(
1096            mut self: Pin<&mut Self>,
1097            cx: &mut Context,
1098        ) -> Poll<Result<(), std::io::Error>> {
1099            Pin::new(&mut self.inner).poll_flush(cx)
1100        }
1101
1102        fn poll_shutdown(
1103            mut self: Pin<&mut Self>,
1104            cx: &mut Context,
1105        ) -> Poll<Result<(), std::io::Error>> {
1106            Pin::new(&mut self.inner).poll_shutdown(cx)
1107        }
1108    }
1109
1110    #[cfg(feature = "__tls")]
1111    impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> {
1112        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
1113            self.inner.tls_info()
1114        }
1115    }
1116
1117    struct Escape<'a>(&'a [u8]);
1118
1119    impl fmt::Debug for Escape<'_> {
1120        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1121            write!(f, "b\"")?;
1122            for &c in self.0 {
1123                // https://doc.rust-lang.org/reference.html#byte-escapes
1124                if c == b'\n' {
1125                    write!(f, "\\n")?;
1126                } else if c == b'\r' {
1127                    write!(f, "\\r")?;
1128                } else if c == b'\t' {
1129                    write!(f, "\\t")?;
1130                } else if c == b'\\' || c == b'"' {
1131                    write!(f, "\\{}", c as char)?;
1132                } else if c == b'\0' {
1133                    write!(f, "\\0")?;
1134                // ASCII printable
1135                } else if c >= 0x20 && c < 0x7f {
1136                    write!(f, "{}", c as char)?;
1137                } else {
1138                    write!(f, "\\x{c:02x}")?;
1139                }
1140            }
1141            write!(f, "\"")?;
1142            Ok(())
1143        }
1144    }
1145
1146    struct Vectored<'a, 'b> {
1147        bufs: &'a [IoSlice<'b>],
1148        nwritten: usize,
1149    }
1150
1151    impl fmt::Debug for Vectored<'_, '_> {
1152        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1153            let mut left = self.nwritten;
1154            for buf in self.bufs.iter() {
1155                if left == 0 {
1156                    break;
1157                }
1158                let n = min(left, buf.len());
1159                Escape(&buf[..n]).fmt(f)?;
1160                left -= n;
1161            }
1162            Ok(())
1163        }
1164    }
1165}
1166
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    // Spawns a one-shot mock proxy on a loopback port and evaluates to its
    // address. The proxy asserts that it receives the expected CONNECT request
    // (optionally with an extra auth header line) and then replies with `$write`.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let server = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = server.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut peer, _) = server.accept().unwrap();
                let mut request = [0u8; 4096];
                let len = peer.read(&mut request).unwrap();
                assert_eq!(&request[..len], &connect_expected[..]);

                peer.write_all($write).unwrap();
            });
            addr
        }};
    }

    // Evaluates to a future that connects to the mock proxy at `$addr` and runs
    // `tunnel` against it, targeting the proxy's own ip/port, with `$auth` as
    // the optional proxy-authorization value.
    macro_rules! tunnel_through {
        ($addr:expr, $auth:expr) => {{
            let addr = $addr;
            async move {
                let tcp = TcpStream::connect(&addr).await?;
                let host = addr.ip().to_string();
                let port = addr.port();
                tunnel(tcp, host, port, ua(), $auth).await
            }
        }};
    }

    /// User-Agent header value sent with every tunnel request in these tests.
    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Builds the single-threaded runtime each test drives its future on.
    fn current_thread_rt() -> runtime::Runtime {
        runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt")
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel!();
        let rt = current_thread_rt();

        rt.block_on(tunnel_through!(addr, None)).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        // Response is cut off before the header terminator.
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
        let rt = current_thread_rt();

        rt.block_on(tunnel_through!(addr, None)).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo");
        let rt = current_thread_rt();

        rt.block_on(tunnel_through!(addr, None)).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
        "
        );
        let rt = current_thread_rt();

        let error = rt.block_on(tunnel_through!(addr, None)).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );
        let rt = current_thread_rt();

        rt.block_on(tunnel_through!(
            addr,
            Some(proxy::encode_basic_auth("Aladdin", "open sesame"))
        ))
        .unwrap();
    }
}