native_tls/imp/
openssl.rs

1use openssl::error::ErrorStack;
2use openssl::hash::MessageDigest;
3use openssl::nid::Nid;
4use openssl::pkcs12::Pkcs12;
5use openssl::pkey::{PKey, Private};
6use openssl::ssl::{
7    self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
8    SslVerifyMode,
9};
10use openssl::x509::store::X509StoreBuilder;
11use openssl::x509::{X509VerifyResult, X509};
12use openssl_probe::ProbeResult;
13use std::sync::LazyLock;
14use std::{error, fmt, io};
15
16use crate::{Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
17use log::debug;
18
19static PROBE_RESULT: LazyLock<ProbeResult> = LazyLock::new(openssl_probe::probe);
20
21#[cfg(have_min_max_version)]
22fn supported_protocols(
23    min: Option<Protocol>,
24    max: Option<Protocol>,
25    ctx: &mut SslContextBuilder,
26) -> Result<(), ErrorStack> {
27    use openssl::ssl::SslVersion;
28
29    fn cvt(p: Protocol) -> SslVersion {
30        match p {
31            Protocol::Sslv3 => SslVersion::SSL3,
32            Protocol::Tlsv10 => SslVersion::TLS1,
33            Protocol::Tlsv11 => SslVersion::TLS1_1,
34            Protocol::Tlsv12 => SslVersion::TLS1_2,
35            Protocol::Tlsv13 => SslVersion::TLS1_3,
36        }
37    }
38
39    ctx.set_min_proto_version(min.map(cvt))?;
40    ctx.set_max_proto_version(max.map(cvt))?;
41
42    Ok(())
43}
44
45#[cfg(not(have_min_max_version))]
46fn supported_protocols(
47    min: Option<Protocol>,
48    max: Option<Protocol>,
49    ctx: &mut SslContextBuilder,
50) -> Result<(), ErrorStack> {
51    use openssl::ssl::SslOptions;
52
53    let no_ssl_mask = SslOptions::NO_SSLV2
54        | SslOptions::NO_SSLV3
55        | SslOptions::NO_TLSV1
56        | SslOptions::NO_TLSV1_1
57        | SslOptions::NO_TLSV1_2;
58
59    ctx.clear_options(no_ssl_mask);
60    let mut options = SslOptions::empty();
61    options |= match min {
62        None => SslOptions::empty(),
63        Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
64        Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
65        Some(Protocol::Tlsv11) => {
66            SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
67        }
68        Some(Protocol::Tlsv12) => {
69            SslOptions::NO_SSLV2
70                | SslOptions::NO_SSLV3
71                | SslOptions::NO_TLSV1
72                | SslOptions::NO_TLSV1_1
73        }
74        Some(Protocol::Tlsv13) => {
75            SslOptions::NO_SSLV2
76                | SslOptions::NO_SSLV3
77                | SslOptions::NO_TLSV1
78                | SslOptions::NO_TLSV1_1
79                | SslOptions::NO_TLSV1_2
80        }
81    };
82    options |= match max {
83        // NO_TLSV1_3 may be unavailalbe in the old versions
84        None | Some(Protocol::Tlsv12 | Protocol::Tlsv13) => SslOptions::empty(),
85        Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
86        Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
87        Some(Protocol::Sslv3) => {
88            SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
89        }
90    };
91
92    ctx.set_options(options);
93
94    Ok(())
95}
96
97#[cfg(target_os = "android")]
98fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
99    use std::fs;
100
101    if let Ok(dir) = fs::read_dir("/system/etc/security/cacerts") {
102        let certs = dir
103            .filter_map(|r| r.ok())
104            .filter_map(|e| fs::read(e.path()).ok())
105            .filter_map(|b| X509::from_pem(&b).ok());
106        for cert in certs {
107            if let Err(err) = connector.cert_store_mut().add_cert(cert) {
108                debug!("load_android_root_certs error: {:?}", err);
109            }
110        }
111    }
112
113    Ok(())
114}
115
116#[derive(Debug)]
117pub enum Error {
118    Normal(ErrorStack),
119    Ssl(ssl::Error, X509VerifyResult),
120    EmptyChain,
121    NotPkcs8,
122    AlpnTooLong,
123}
124
125impl error::Error for Error {
126    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
127        match *self {
128            Error::Normal(ref e) => error::Error::source(e),
129            Error::Ssl(ref e, _) => error::Error::source(e),
130            Error::EmptyChain => None,
131            Error::NotPkcs8 => None,
132            Error::AlpnTooLong => None,
133        }
134    }
135}
136
137impl fmt::Display for Error {
138    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
139        match *self {
140            Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
141            Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
142            Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
143            Error::EmptyChain => write!(
144                fmt,
145                "at least one certificate must be provided to create an identity"
146            ),
147            Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
148            Error::AlpnTooLong => write!(fmt, "ALPN too long"),
149        }
150    }
151}
152
153impl From<ErrorStack> for Error {
154    fn from(err: ErrorStack) -> Error {
155        Error::Normal(err)
156    }
157}
158
159#[derive(Clone)]
160pub struct Identity {
161    pkey: PKey<Private>,
162    cert: X509,
163    chain: Vec<X509>,
164}
165
166impl Identity {
167    pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
168        let pkcs12 = Pkcs12::from_der(buf)?;
169        let parsed = pkcs12.parse2(pass)?;
170        Ok(Identity {
171            pkey: parsed.pkey.ok_or_else(|| Error::EmptyChain)?,
172            cert: parsed.cert.ok_or_else(|| Error::EmptyChain)?,
173            // > The stack is the reverse of what you might expect due to the way
174            // > PKCS12_parse is implemented, so we need to load it backwards.
175            // > https://github.com/sfackler/rust-native-tls/commit/05fb5e583be589ab63d9f83d986d095639f8ec44
176            chain: parsed.ca.into_iter().flatten().rev().collect(),
177        })
178    }
179
180    pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
181        if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
182            return Err(Error::NotPkcs8);
183        }
184
185        let pkey = PKey::private_key_from_pem(key)?;
186        let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
187        let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
188        let chain = cert_chain.collect();
189        Ok(Identity { pkey, cert, chain })
190    }
191}
192
193#[derive(Clone)]
194pub struct Certificate(X509);
195
196impl Certificate {
197    pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
198        let cert = X509::from_der(buf)?;
199        Ok(Certificate(cert))
200    }
201
202    pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
203        let cert = X509::from_pem(buf)?;
204        Ok(Certificate(cert))
205    }
206
207    pub fn stack_from_pem(buf: &[u8]) -> Result<Vec<Certificate>, Error> {
208        let certs = X509::stack_from_pem(buf)?;
209        Ok(certs.into_iter().map(Certificate).collect())
210    }
211
212    pub fn to_der(&self) -> Result<Vec<u8>, Error> {
213        let der = self.0.to_der()?;
214        Ok(der)
215    }
216}
217
218pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
219
220impl<S> fmt::Debug for MidHandshakeTlsStream<S>
221where
222    S: fmt::Debug,
223{
224    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
225        fmt::Debug::fmt(&self.0, fmt)
226    }
227}
228
229impl<S> MidHandshakeTlsStream<S> {
230    pub fn get_ref(&self) -> &S {
231        self.0.get_ref()
232    }
233
234    pub fn get_mut(&mut self) -> &mut S {
235        self.0.get_mut()
236    }
237}
238
239impl<S> MidHandshakeTlsStream<S>
240where
241    S: io::Read + io::Write,
242{
243    pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
244        match self.0.handshake() {
245            Ok(s) => Ok(TlsStream(s)),
246            Err(e) => Err(e.into()),
247        }
248    }
249}
250
251pub enum HandshakeError<S> {
252    Failure(Error),
253    WouldBlock(MidHandshakeTlsStream<S>),
254}
255
256impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
257    fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
258        match e {
259            ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
260            ssl::HandshakeError::Failure(e) => {
261                let v = e.ssl().verify_result();
262                HandshakeError::Failure(Error::Ssl(e.into_error(), v))
263            }
264            ssl::HandshakeError::WouldBlock(s) => {
265                HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
266            }
267        }
268    }
269}
270
271impl<S> From<ErrorStack> for HandshakeError<S> {
272    fn from(e: ErrorStack) -> HandshakeError<S> {
273        HandshakeError::Failure(e.into())
274    }
275}
276
277#[derive(Clone)]
278pub struct TlsConnector {
279    connector: SslConnector,
280    use_sni: bool,
281    accept_invalid_hostnames: bool,
282    accept_invalid_certs: bool,
283}
284
285impl TlsConnector {
286    pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
287        let mut connector = SslConnector::builder(SslMethod::tls())?;
288
289        // We need to load these separately so an error on one doesn't prevent the other from loading.
290        if let Some(cert_file) = &PROBE_RESULT.cert_file {
291            if let Err(e) = connector.load_verify_locations(Some(cert_file), None) {
292                debug!("load_verify_locations cert file error: {:?}", e);
293            }
294        }
295        for cert_dir in &PROBE_RESULT.cert_dir {
296            if let Err(e) = connector.load_verify_locations(None, Some(cert_dir)) {
297                debug!("load_verify_locations cert dir error: {:?}", e);
298            }
299        }
300
301        if let Some(ref identity) = builder.identity {
302            connector.set_certificate(&identity.0.cert)?;
303            connector.set_private_key(&identity.0.pkey)?;
304            for cert in identity.0.chain.iter() {
305                // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
306                // specifies that "When sending a certificate chain, extra chain certificates are
307                // sent in order following the end entity certificate."
308                connector.add_extra_chain_cert(cert.to_owned())?;
309            }
310        }
311        supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
312
313        if builder.disable_built_in_roots {
314            connector.set_cert_store(X509StoreBuilder::new()?.build());
315        }
316
317        for cert in &builder.root_certificates {
318            if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
319                debug!("add_cert error: {:?}", err);
320            }
321        }
322
323        #[cfg(feature = "alpn")]
324        if !builder.alpn.is_empty() {
325            connector.set_alpn_protos(&alpn_wire_format(&builder.alpn)?)?;
326        }
327
328        #[cfg(target_os = "android")]
329        load_android_root_certs(&mut connector)?;
330
331        Ok(TlsConnector {
332            connector: connector.build(),
333            use_sni: builder.use_sni,
334            accept_invalid_hostnames: builder.accept_invalid_hostnames,
335            accept_invalid_certs: builder.accept_invalid_certs,
336        })
337    }
338
339    pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
340    where
341        S: io::Read + io::Write,
342    {
343        let mut ssl = self
344            .connector
345            .configure()?
346            .use_server_name_indication(self.use_sni)
347            .verify_hostname(!self.accept_invalid_hostnames);
348        if self.accept_invalid_certs {
349            ssl.set_verify(SslVerifyMode::NONE);
350        }
351
352        let s = ssl.connect(domain, stream)?;
353        Ok(TlsStream(s))
354    }
355}
356
357#[cfg(any(feature = "alpn", feature = "alpn-accept"))]
358fn alpn_wire_format(alpn_list: &[Box<str>]) -> Result<Vec<u8>, Error> {
359    // Wire format is each alpn preceded by its length as a byte.
360    let mut alpn_wire_format =
361        Vec::with_capacity(alpn_list.iter().map(|s| s.len()).sum::<usize>() + alpn_list.len());
362    for alpn in alpn_list.iter().map(|s| s.as_bytes()) {
363        let len_byte = alpn.len().try_into().map_err(|_| Error::AlpnTooLong)?;
364
365        if alpn_wire_format.capacity() - alpn_wire_format.len() >= 1 {
366            alpn_wire_format.push(len_byte);
367        }
368        if alpn_wire_format.capacity() - alpn_wire_format.len() >= alpn.len() {
369            alpn_wire_format.extend(alpn);
370        }
371    }
372    Ok(alpn_wire_format)
373}
374
375impl fmt::Debug for TlsConnector {
376    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
377        fmt.debug_struct("TlsConnector")
378            // n.b. SslConnector is a newtype on SslContext which implements a noop Debug so it's omitted
379            .field("use_sni", &self.use_sni)
380            .field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
381            .field("accept_invalid_certs", &self.accept_invalid_certs)
382            .finish()
383    }
384}
385
386#[derive(Clone)]
387pub struct TlsAcceptor(SslAcceptor);
388
389impl TlsAcceptor {
390    pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
391        let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
392        acceptor.set_private_key(&builder.identity.0.pkey)?;
393        acceptor.set_certificate(&builder.identity.0.cert)?;
394        #[cfg(feature = "alpn-accept")]
395        if !builder.accept_alpn.is_empty() {
396            let alpn_wire_format = alpn_wire_format(&builder.accept_alpn)?;
397            acceptor.set_alpn_protos(&alpn_wire_format)?;
398            // set up ALPN selection routine - as select_next_proto
399            acceptor.set_alpn_select_callback(move |_: &mut openssl::ssl::SslRef, client_list: &[u8]| {
400                openssl::ssl::select_next_proto(&alpn_wire_format, client_list).and_then(|selected| {
401                    if selected.is_empty() || selected.len() > client_list.len() {
402                        return None;
403                    }
404                    // return string from the client list to separate it from alpn_wire_format's lifetime
405                    // https://github.com/rust-openssl/rust-openssl/pull/2360#issuecomment-2651522324
406                    client_list.windows(selected.len()).find(|&item| item == selected)
407                })
408                .ok_or(openssl::ssl::AlpnError::NOACK)
409            });
410        }
411        for cert in builder.identity.0.chain.iter() {
412            // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
413            // specifies that "When sending a certificate chain, extra chain certificates are
414            // sent in order following the end entity certificate."
415            acceptor.add_extra_chain_cert(cert.to_owned())?;
416        }
417        supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
418
419        Ok(TlsAcceptor(acceptor.build()))
420    }
421
422    pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
423    where
424        S: io::Read + io::Write,
425    {
426        let s = self.0.accept(stream)?;
427        Ok(TlsStream(s))
428    }
429}
430
431pub struct TlsStream<S>(ssl::SslStream<S>);
432
433impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
434    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
435        fmt::Debug::fmt(&self.0, fmt)
436    }
437}
438
439impl<S> TlsStream<S> {
440    pub fn get_ref(&self) -> &S {
441        self.0.get_ref()
442    }
443
444    pub fn get_mut(&mut self) -> &mut S {
445        self.0.get_mut()
446    }
447}
448
449impl<S: io::Read + io::Write> TlsStream<S> {
450    pub fn buffered_read_size(&self) -> Result<usize, Error> {
451        Ok(self.0.ssl().pending())
452    }
453
454    pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
455        Ok(self.0.ssl().peer_certificate().map(Certificate))
456    }
457
458    #[cfg(feature = "alpn")]
459    pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
460        Ok(self
461            .0
462            .ssl()
463            .selected_alpn_protocol()
464            .map(|alpn| alpn.to_vec()))
465    }
466
467    pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
468        let cert = if self.0.ssl().is_server() {
469            self.0.ssl().certificate().map(|x| x.to_owned())
470        } else {
471            self.0.ssl().peer_certificate()
472        };
473
474        let cert = match cert {
475            Some(cert) => cert,
476            None => return Ok(None),
477        };
478
479        let algo_nid = cert.signature_algorithm().object().nid();
480        let signature_algorithms = match algo_nid.signature_algorithms() {
481            Some(algs) => algs,
482            None => return Ok(None),
483        };
484
485        let md = match signature_algorithms.digest {
486            Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
487            nid => match MessageDigest::from_nid(nid) {
488                Some(md) => md,
489                None => return Ok(None),
490            },
491        };
492
493        let digest = cert.digest(md)?;
494
495        Ok(Some(digest.to_vec()))
496    }
497
498    pub fn shutdown(&mut self) -> io::Result<()> {
499        match self.0.shutdown() {
500            Ok(_) => Ok(()),
501            Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
502            Err(e) => Err(e.into_io_error().unwrap_or_else(io::Error::other)),
503        }
504    }
505}
506
507impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
508    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
509        self.0.read(buf)
510    }
511}
512
513impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
514    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
515        self.0.write(buf)
516    }
517
518    fn flush(&mut self) -> io::Result<()> {
519        self.0.flush()
520    }
521}