// hyper_tls/client.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use hyper::{client::connect::HttpConnector, service::Service, Uri};
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_native_tls::TlsConnector;
9
10use crate::stream::MaybeHttpsStream;
11
/// Boxed, thread-safe error type used as the connector's `Service::Error`
/// and as the error half of every `HttpsConnecting` result.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
13
14/// A Connector for the `https` scheme.
15#[derive(Clone)]
16pub struct HttpsConnector<T> {
17    force_https: bool,
18    http: T,
19    tls: TlsConnector,
20}
21
22impl HttpsConnector<HttpConnector> {
23    /// Construct a new HttpsConnector.
24    ///
25    /// This uses hyper's default `HttpConnector`, and default `TlsConnector`.
26    /// If you wish to use something besides the defaults, use `From::from`.
27    ///
28    /// # Note
29    ///
30    /// By default this connector will use plain HTTP if the URL provded uses
31    /// the HTTP scheme (eg: http://example.com/).
32    ///
33    /// If you would like to force the use of HTTPS then call https_only(true)
34    /// on the returned connector.
35    ///
36    /// # Panics
37    ///
38    /// This will panic if the underlying TLS context could not be created.
39    ///
40    /// To handle that error yourself, you can use the `HttpsConnector::from`
41    /// constructor after trying to make a `TlsConnector`.
42    pub fn new() -> Self {
43        native_tls::TlsConnector::new()
44            .map(|tls| HttpsConnector::new_(tls.into()))
45            .unwrap_or_else(|e| panic!("HttpsConnector::new() failure: {}", e))
46    }
47
48    fn new_(tls: TlsConnector) -> Self {
49        let mut http = HttpConnector::new();
50        http.enforce_http(false);
51        HttpsConnector::from((http, tls))
52    }
53}
54
55impl<T: Default> Default for HttpsConnector<T> {
56    fn default() -> Self {
57        Self::new_with_connector(Default::default())
58    }
59}
60
61impl<T> HttpsConnector<T> {
62    /// Force the use of HTTPS when connecting.
63    ///
64    /// If a URL is not `https` when connecting, an error is returned.
65    pub fn https_only(&mut self, enable: bool) {
66        self.force_https = enable;
67    }
68
69    /// With connector constructor
70    ///
71    pub fn new_with_connector(http: T) -> Self {
72        native_tls::TlsConnector::new()
73            .map(|tls| HttpsConnector::from((http, tls.into())))
74            .unwrap_or_else(|e| {
75                panic!(
76                    "HttpsConnector::new_with_connector(<connector>) failure: {}",
77                    e
78                )
79            })
80    }
81}
82
83impl<T> From<(T, TlsConnector)> for HttpsConnector<T> {
84    fn from(args: (T, TlsConnector)) -> HttpsConnector<T> {
85        HttpsConnector {
86            force_https: false,
87            http: args.0,
88            tls: args.1,
89        }
90    }
91}
92
93impl<T: fmt::Debug> fmt::Debug for HttpsConnector<T> {
94    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
95        f.debug_struct("HttpsConnector")
96            .field("force_https", &self.force_https)
97            .field("http", &self.http)
98            .finish()
99    }
100}
101
102impl<T> Service<Uri> for HttpsConnector<T>
103where
104    T: Service<Uri>,
105    T::Response: AsyncRead + AsyncWrite + Send + Unpin,
106    T::Future: Send + 'static,
107    T::Error: Into<BoxError>,
108{
109    type Response = MaybeHttpsStream<T::Response>;
110    type Error = BoxError;
111    type Future = HttpsConnecting<T::Response>;
112
113    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114        match self.http.poll_ready(cx) {
115            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
116            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
117            Poll::Pending => Poll::Pending,
118        }
119    }
120
121    fn call(&mut self, dst: Uri) -> Self::Future {
122        let is_https = dst.scheme_str() == Some("https");
123        // Early abort if HTTPS is forced but can't be used
124        if !is_https && self.force_https {
125            return err(ForceHttpsButUriNotHttps.into());
126        }
127
128        let host = dst
129            .host()
130            .unwrap_or("")
131            .trim_matches(|c| c == '[' || c == ']')
132            .to_owned();
133        let connecting = self.http.call(dst);
134        let tls = self.tls.clone();
135        let fut = async move {
136            let tcp = connecting.await.map_err(Into::into)?;
137            let maybe = if is_https {
138                let tls = tls.connect(&host, tcp).await?;
139                MaybeHttpsStream::Https(tls)
140            } else {
141                MaybeHttpsStream::Http(tcp)
142            };
143            Ok(maybe)
144        };
145        HttpsConnecting(Box::pin(fut))
146    }
147}
148
149fn err<T>(e: BoxError) -> HttpsConnecting<T> {
150    HttpsConnecting(Box::pin(async { Err(e) }))
151}
152
153type BoxedFut<T> = Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T>, BoxError>> + Send>>;
154
155/// A Future representing work to connect to a URL, and a TLS handshake.
156pub struct HttpsConnecting<T>(BoxedFut<T>);
157
158impl<T: AsyncRead + AsyncWrite + Unpin> Future for HttpsConnecting<T> {
159    type Output = Result<MaybeHttpsStream<T>, BoxError>;
160
161    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162        Pin::new(&mut self.0).poll(cx)
163    }
164}
165
166impl<T> fmt::Debug for HttpsConnecting<T> {
167    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
168        f.pad("HttpsConnecting")
169    }
170}
171
// ===== Custom Errors =====

/// Error produced when `https_only(true)` is set but the destination URI's
/// scheme is not `https`.
#[derive(Debug)]
struct ForceHttpsButUriNotHttps;

impl std::error::Error for ForceHttpsButUriNotHttps {}

impl fmt::Display for ForceHttpsButUriNotHttps {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(out, "https required but URI was not https")
    }
}