1use openssl::error::ErrorStack;
2use openssl::hash::MessageDigest;
3use openssl::nid::Nid;
4use openssl::pkcs12::Pkcs12;
5use openssl::pkey::{PKey, Private};
6use openssl::ssl::{
7 self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
8 SslVerifyMode,
9};
10use openssl::x509::store::X509StoreBuilder;
11use openssl::x509::{X509VerifyResult, X509};
12use openssl_probe::ProbeResult;
13use std::sync::LazyLock;
14use std::{error, fmt, io};
15
16use crate::{Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
17use log::debug;
18
19static PROBE_RESULT: LazyLock<ProbeResult> = LazyLock::new(openssl_probe::probe);
20
/// Restricts `ctx` to the protocol range `[min, max]` using OpenSSL's native
/// minimum/maximum protocol version setters.
///
/// A `None` bound is forwarded as `None` to the corresponding setter.
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use openssl::ssl::SslVersion;

    // Stateless mapping from the crate-level protocol enum to OpenSSL's version type.
    let to_ssl_version = |protocol: Protocol| match protocol {
        Protocol::Sslv3 => SslVersion::SSL3,
        Protocol::Tlsv10 => SslVersion::TLS1,
        Protocol::Tlsv11 => SslVersion::TLS1_1,
        Protocol::Tlsv12 => SslVersion::TLS1_2,
        Protocol::Tlsv13 => SslVersion::TLS1_3,
    };

    let lower = min.map(to_ssl_version);
    let upper = max.map(to_ssl_version);
    ctx.set_min_proto_version(lower)?;
    ctx.set_max_proto_version(upper)?;

    Ok(())
}
44
45#[cfg(not(have_min_max_version))]
46fn supported_protocols(
47 min: Option<Protocol>,
48 max: Option<Protocol>,
49 ctx: &mut SslContextBuilder,
50) -> Result<(), ErrorStack> {
51 use openssl::ssl::SslOptions;
52
53 let no_ssl_mask = SslOptions::NO_SSLV2
54 | SslOptions::NO_SSLV3
55 | SslOptions::NO_TLSV1
56 | SslOptions::NO_TLSV1_1
57 | SslOptions::NO_TLSV1_2;
58
59 ctx.clear_options(no_ssl_mask);
60 let mut options = SslOptions::empty();
61 options |= match min {
62 None => SslOptions::empty(),
63 Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
64 Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
65 Some(Protocol::Tlsv11) => {
66 SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
67 }
68 Some(Protocol::Tlsv12) => {
69 SslOptions::NO_SSLV2
70 | SslOptions::NO_SSLV3
71 | SslOptions::NO_TLSV1
72 | SslOptions::NO_TLSV1_1
73 }
74 Some(Protocol::Tlsv13) => {
75 SslOptions::NO_SSLV2
76 | SslOptions::NO_SSLV3
77 | SslOptions::NO_TLSV1
78 | SslOptions::NO_TLSV1_1
79 | SslOptions::NO_TLSV1_2
80 }
81 };
82 options |= match max {
83 None | Some(Protocol::Tlsv12 | Protocol::Tlsv13) => SslOptions::empty(),
85 Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
86 Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
87 Some(Protocol::Sslv3) => {
88 SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
89 }
90 };
91
92 ctx.set_options(options);
93
94 Ok(())
95}
96
/// Adds every PEM certificate found in Android's system CA directory to the context's
/// trust store.
///
/// A missing or unreadable directory, unreadable files and non-PEM files are skipped
/// silently; store insertion failures are logged at debug level. Always returns `Ok(())`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let Ok(entries) = fs::read_dir("/system/etc/security/cacerts") else {
        return Ok(());
    };

    for entry in entries.flatten() {
        let Ok(bytes) = fs::read(entry.path()) else {
            continue;
        };
        let Ok(cert) = X509::from_pem(&bytes) else {
            continue;
        };
        if let Err(err) = connector.cert_store_mut().add_cert(cert) {
            debug!("load_android_root_certs error: {:?}", err);
        }
    }

    Ok(())
}
115
116#[derive(Debug)]
117pub enum Error {
118 Normal(ErrorStack),
119 Ssl(ssl::Error, X509VerifyResult),
120 EmptyChain,
121 NotPkcs8,
122 AlpnTooLong,
123}
124
125impl error::Error for Error {
126 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
127 match *self {
128 Error::Normal(ref e) => error::Error::source(e),
129 Error::Ssl(ref e, _) => error::Error::source(e),
130 Error::EmptyChain => None,
131 Error::NotPkcs8 => None,
132 Error::AlpnTooLong => None,
133 }
134 }
135}
136
137impl fmt::Display for Error {
138 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
139 match *self {
140 Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
141 Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
142 Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
143 Error::EmptyChain => write!(
144 fmt,
145 "at least one certificate must be provided to create an identity"
146 ),
147 Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
148 Error::AlpnTooLong => write!(fmt, "ALPN too long"),
149 }
150 }
151}
152
153impl From<ErrorStack> for Error {
154 fn from(err: ErrorStack) -> Error {
155 Error::Normal(err)
156 }
157}
158
159#[derive(Clone)]
160pub struct Identity {
161 pkey: PKey<Private>,
162 cert: X509,
163 chain: Vec<X509>,
164}
165
166impl Identity {
167 pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
168 let pkcs12 = Pkcs12::from_der(buf)?;
169 let parsed = pkcs12.parse2(pass)?;
170 Ok(Identity {
171 pkey: parsed.pkey.ok_or_else(|| Error::EmptyChain)?,
172 cert: parsed.cert.ok_or_else(|| Error::EmptyChain)?,
173 chain: parsed.ca.into_iter().flatten().rev().collect(),
177 })
178 }
179
180 pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
181 if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
182 return Err(Error::NotPkcs8);
183 }
184
185 let pkey = PKey::private_key_from_pem(key)?;
186 let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
187 let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
188 let chain = cert_chain.collect();
189 Ok(Identity { pkey, cert, chain })
190 }
191}
192
193#[derive(Clone)]
194pub struct Certificate(X509);
195
196impl Certificate {
197 pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
198 let cert = X509::from_der(buf)?;
199 Ok(Certificate(cert))
200 }
201
202 pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
203 let cert = X509::from_pem(buf)?;
204 Ok(Certificate(cert))
205 }
206
207 pub fn stack_from_pem(buf: &[u8]) -> Result<Vec<Certificate>, Error> {
208 let certs = X509::stack_from_pem(buf)?;
209 Ok(certs.into_iter().map(Certificate).collect())
210 }
211
212 pub fn to_der(&self) -> Result<Vec<u8>, Error> {
213 let der = self.0.to_der()?;
214 Ok(der)
215 }
216}
217
218pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
219
220impl<S> fmt::Debug for MidHandshakeTlsStream<S>
221where
222 S: fmt::Debug,
223{
224 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
225 fmt::Debug::fmt(&self.0, fmt)
226 }
227}
228
229impl<S> MidHandshakeTlsStream<S> {
230 pub fn get_ref(&self) -> &S {
231 self.0.get_ref()
232 }
233
234 pub fn get_mut(&mut self) -> &mut S {
235 self.0.get_mut()
236 }
237}
238
239impl<S> MidHandshakeTlsStream<S>
240where
241 S: io::Read + io::Write,
242{
243 pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
244 match self.0.handshake() {
245 Ok(s) => Ok(TlsStream(s)),
246 Err(e) => Err(e.into()),
247 }
248 }
249}
250
251pub enum HandshakeError<S> {
252 Failure(Error),
253 WouldBlock(MidHandshakeTlsStream<S>),
254}
255
256impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
257 fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
258 match e {
259 ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
260 ssl::HandshakeError::Failure(e) => {
261 let v = e.ssl().verify_result();
262 HandshakeError::Failure(Error::Ssl(e.into_error(), v))
263 }
264 ssl::HandshakeError::WouldBlock(s) => {
265 HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
266 }
267 }
268 }
269}
270
271impl<S> From<ErrorStack> for HandshakeError<S> {
272 fn from(e: ErrorStack) -> HandshakeError<S> {
273 HandshakeError::Failure(e.into())
274 }
275}
276
277#[derive(Clone)]
278pub struct TlsConnector {
279 connector: SslConnector,
280 use_sni: bool,
281 accept_invalid_hostnames: bool,
282 accept_invalid_certs: bool,
283}
284
285impl TlsConnector {
286 pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
287 let mut connector = SslConnector::builder(SslMethod::tls())?;
288
289 if let Some(cert_file) = &PROBE_RESULT.cert_file {
291 if let Err(e) = connector.load_verify_locations(Some(cert_file), None) {
292 debug!("load_verify_locations cert file error: {:?}", e);
293 }
294 }
295 for cert_dir in &PROBE_RESULT.cert_dir {
296 if let Err(e) = connector.load_verify_locations(None, Some(cert_dir)) {
297 debug!("load_verify_locations cert dir error: {:?}", e);
298 }
299 }
300
301 if let Some(ref identity) = builder.identity {
302 connector.set_certificate(&identity.0.cert)?;
303 connector.set_private_key(&identity.0.pkey)?;
304 for cert in identity.0.chain.iter() {
305 connector.add_extra_chain_cert(cert.to_owned())?;
309 }
310 }
311 supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
312
313 if builder.disable_built_in_roots {
314 connector.set_cert_store(X509StoreBuilder::new()?.build());
315 }
316
317 for cert in &builder.root_certificates {
318 if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
319 debug!("add_cert error: {:?}", err);
320 }
321 }
322
323 #[cfg(feature = "alpn")]
324 if !builder.alpn.is_empty() {
325 connector.set_alpn_protos(&alpn_wire_format(&builder.alpn)?)?;
326 }
327
328 #[cfg(target_os = "android")]
329 load_android_root_certs(&mut connector)?;
330
331 Ok(TlsConnector {
332 connector: connector.build(),
333 use_sni: builder.use_sni,
334 accept_invalid_hostnames: builder.accept_invalid_hostnames,
335 accept_invalid_certs: builder.accept_invalid_certs,
336 })
337 }
338
339 pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
340 where
341 S: io::Read + io::Write,
342 {
343 let mut ssl = self
344 .connector
345 .configure()?
346 .use_server_name_indication(self.use_sni)
347 .verify_hostname(!self.accept_invalid_hostnames);
348 if self.accept_invalid_certs {
349 ssl.set_verify(SslVerifyMode::NONE);
350 }
351
352 let s = ssl.connect(domain, stream)?;
353 Ok(TlsStream(s))
354 }
355}
356
/// Encodes `alpn_list` in ALPN wire format: each protocol name preceded by a single
/// length byte, concatenated in the given order.
///
/// # Errors
/// `Error::AlpnTooLong` if any protocol name is longer than 255 bytes.
#[cfg(any(feature = "alpn", feature = "alpn-accept"))]
fn alpn_wire_format(alpn_list: &[Box<str>]) -> Result<Vec<u8>, Error> {
    // Exact output size: one length byte plus the name bytes for every entry.
    let total = alpn_list.iter().map(|s| s.len()).sum::<usize>() + alpn_list.len();
    let mut wire = Vec::with_capacity(total);
    for alpn in alpn_list.iter().map(|s| s.as_bytes()) {
        let len_byte: u8 = alpn.len().try_into().map_err(|_| Error::AlpnTooLong)?;
        // The exact size was reserved above, so neither call reallocates. Both are
        // unconditional, so the length byte and its name can never get out of step.
        wire.push(len_byte);
        wire.extend_from_slice(alpn);
    }
    Ok(wire)
}
374
375impl fmt::Debug for TlsConnector {
376 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
377 fmt.debug_struct("TlsConnector")
378 .field("use_sni", &self.use_sni)
380 .field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
381 .field("accept_invalid_certs", &self.accept_invalid_certs)
382 .finish()
383 }
384}
385
386#[derive(Clone)]
387pub struct TlsAcceptor(SslAcceptor);
388
389impl TlsAcceptor {
390 pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
391 let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
392 acceptor.set_private_key(&builder.identity.0.pkey)?;
393 acceptor.set_certificate(&builder.identity.0.cert)?;
394 #[cfg(feature = "alpn-accept")]
395 if !builder.accept_alpn.is_empty() {
396 let alpn_wire_format = alpn_wire_format(&builder.accept_alpn)?;
397 acceptor.set_alpn_protos(&alpn_wire_format)?;
398 acceptor.set_alpn_select_callback(move |_: &mut openssl::ssl::SslRef, client_list: &[u8]| {
400 openssl::ssl::select_next_proto(&alpn_wire_format, client_list).and_then(|selected| {
401 if selected.is_empty() || selected.len() > client_list.len() {
402 return None;
403 }
404 client_list.windows(selected.len()).find(|&item| item == selected)
407 })
408 .ok_or(openssl::ssl::AlpnError::NOACK)
409 });
410 }
411 for cert in builder.identity.0.chain.iter() {
412 acceptor.add_extra_chain_cert(cert.to_owned())?;
416 }
417 supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
418
419 Ok(TlsAcceptor(acceptor.build()))
420 }
421
422 pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
423 where
424 S: io::Read + io::Write,
425 {
426 let s = self.0.accept(stream)?;
427 Ok(TlsStream(s))
428 }
429}
430
431pub struct TlsStream<S>(ssl::SslStream<S>);
432
433impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
434 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
435 fmt::Debug::fmt(&self.0, fmt)
436 }
437}
438
439impl<S> TlsStream<S> {
440 pub fn get_ref(&self) -> &S {
441 self.0.get_ref()
442 }
443
444 pub fn get_mut(&mut self) -> &mut S {
445 self.0.get_mut()
446 }
447}
448
449impl<S: io::Read + io::Write> TlsStream<S> {
450 pub fn buffered_read_size(&self) -> Result<usize, Error> {
451 Ok(self.0.ssl().pending())
452 }
453
454 pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
455 Ok(self.0.ssl().peer_certificate().map(Certificate))
456 }
457
458 #[cfg(feature = "alpn")]
459 pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
460 Ok(self
461 .0
462 .ssl()
463 .selected_alpn_protocol()
464 .map(|alpn| alpn.to_vec()))
465 }
466
467 pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
468 let cert = if self.0.ssl().is_server() {
469 self.0.ssl().certificate().map(|x| x.to_owned())
470 } else {
471 self.0.ssl().peer_certificate()
472 };
473
474 let cert = match cert {
475 Some(cert) => cert,
476 None => return Ok(None),
477 };
478
479 let algo_nid = cert.signature_algorithm().object().nid();
480 let signature_algorithms = match algo_nid.signature_algorithms() {
481 Some(algs) => algs,
482 None => return Ok(None),
483 };
484
485 let md = match signature_algorithms.digest {
486 Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
487 nid => match MessageDigest::from_nid(nid) {
488 Some(md) => md,
489 None => return Ok(None),
490 },
491 };
492
493 let digest = cert.digest(md)?;
494
495 Ok(Some(digest.to_vec()))
496 }
497
498 pub fn shutdown(&mut self) -> io::Result<()> {
499 match self.0.shutdown() {
500 Ok(_) => Ok(()),
501 Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
502 Err(e) => Err(e.into_io_error().unwrap_or_else(io::Error::other)),
503 }
504 }
505}
506
507impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
508 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
509 self.0.read(buf)
510 }
511}
512
513impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
514 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
515 self.0.write(buf)
516 }
517
518 fn flush(&mut self) -> io::Result<()> {
519 self.0.flush()
520 }
521}