hyper_tls/
stream.rs

1use std::fmt;
2use std::io;
3use std::io::IoSlice;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use hyper::client::connect::{Connected, Connection};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9pub use tokio_native_tls::TlsStream;
10
/// A stream that might be protected with TLS.
///
/// One type that can stand for either kind of connection. The trait impls
/// in this module (`Debug`, `AsyncRead`, `AsyncWrite`, `Connection`) match on
/// the variant and forward to the wrapped stream, and the `From` impls build
/// the matching variant from a bare `T` or from a [`TlsStream`].
pub enum MaybeHttpsStream<T> {
    /// A stream over plain text: the transport `T` is used directly.
    Http(T),
    /// A stream protected with TLS: the transport `T` wrapped in a [`TlsStream`].
    Https(TlsStream<T>),
}
18
19// ===== impl MaybeHttpsStream =====
20
21impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
22    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
23        match self {
24            MaybeHttpsStream::Http(s) => f.debug_tuple("Http").field(s).finish(),
25            MaybeHttpsStream::Https(s) => f.debug_tuple("Https").field(s).finish(),
26        }
27    }
28}
29
30impl<T> From<T> for MaybeHttpsStream<T> {
31    fn from(inner: T) -> Self {
32        MaybeHttpsStream::Http(inner)
33    }
34}
35
36impl<T> From<TlsStream<T>> for MaybeHttpsStream<T> {
37    fn from(inner: TlsStream<T>) -> Self {
38        MaybeHttpsStream::Https(inner)
39    }
40}
41
42impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeHttpsStream<T> {
43    #[inline]
44    fn poll_read(
45        self: Pin<&mut Self>,
46        cx: &mut Context,
47        buf: &mut ReadBuf,
48    ) -> Poll<Result<(), io::Error>> {
49        match Pin::get_mut(self) {
50            MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf),
51            MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf),
52        }
53    }
54}
55
56impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for MaybeHttpsStream<T> {
57    #[inline]
58    fn poll_write(
59        self: Pin<&mut Self>,
60        cx: &mut Context<'_>,
61        buf: &[u8],
62    ) -> Poll<Result<usize, io::Error>> {
63        match Pin::get_mut(self) {
64            MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf),
65            MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf),
66        }
67    }
68
69    fn poll_write_vectored(
70        self: Pin<&mut Self>,
71        cx: &mut Context<'_>,
72        bufs: &[IoSlice<'_>],
73    ) -> Poll<Result<usize, io::Error>> {
74        match Pin::get_mut(self) {
75            MaybeHttpsStream::Http(s) => Pin::new(s).poll_write_vectored(cx, bufs),
76            MaybeHttpsStream::Https(s) => Pin::new(s).poll_write_vectored(cx, bufs),
77        }
78    }
79
80    fn is_write_vectored(&self) -> bool {
81        match self {
82            MaybeHttpsStream::Http(s) => s.is_write_vectored(),
83            MaybeHttpsStream::Https(s) => s.is_write_vectored(),
84        }
85    }
86
87    #[inline]
88    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
89        match Pin::get_mut(self) {
90            MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx),
91            MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx),
92        }
93    }
94
95    #[inline]
96    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
97        match Pin::get_mut(self) {
98            MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx),
99            MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx),
100        }
101    }
102}
103
104impl<T: AsyncRead + AsyncWrite + Connection + Unpin> Connection for MaybeHttpsStream<T> {
105    fn connected(&self) -> Connected {
106        match self {
107            MaybeHttpsStream::Http(s) => s.connected(),
108            MaybeHttpsStream::Https(s) => s.get_ref().get_ref().get_ref().connected(),
109        }
110    }
111}