actix_http/h1/encoder.rs

1use std::{
2    cmp,
3    io::{self, Write as _},
4    marker::PhantomData,
5    ptr::copy_nonoverlapping,
6    slice::from_raw_parts_mut,
7};
8
9use bytes::{BufMut, BytesMut};
10
11use crate::{
12    body::BodySize,
13    header::{
14        map::Value, HeaderMap, HeaderName, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
15    },
16    helpers, ConnectionType, RequestHeadType, Response, ServiceConfig, StatusCode, Version,
17};
18
/// Estimated number of bytes per header line, used to pre-size the output buffer before a
/// message head is encoded.
const AVERAGE_HEADER_SIZE: usize = 30;
20
21#[derive(Debug)]
22pub(crate) struct MessageEncoder<T: MessageType> {
23    #[allow(dead_code)]
24    pub length: BodySize,
25    pub te: TransferEncoding,
26    _phantom: PhantomData<T>,
27}
28
29impl<T: MessageType> Default for MessageEncoder<T> {
30    fn default() -> Self {
31        MessageEncoder {
32            length: BodySize::None,
33            te: TransferEncoding::empty(),
34            _phantom: PhantomData,
35        }
36    }
37}
38
39pub(crate) trait MessageType: Sized {
40    fn status(&self) -> Option<StatusCode>;
41
42    fn headers(&self) -> &HeaderMap;
43
44    fn extra_headers(&self) -> Option<&HeaderMap>;
45
46    fn camel_case(&self) -> bool {
47        false
48    }
49
50    fn chunked(&self) -> bool;
51
52    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;
53
54    fn encode_headers(
55        &mut self,
56        dst: &mut BytesMut,
57        version: Version,
58        mut length: BodySize,
59        conn_type: ConnectionType,
60        config: &ServiceConfig,
61    ) -> io::Result<()> {
62        let chunked = self.chunked();
63        let mut skip_len = length != BodySize::Stream;
64        let camel_case = self.camel_case();
65
66        // Content length
67        if let Some(status) = self.status() {
68            match status {
69                StatusCode::CONTINUE
70                | StatusCode::SWITCHING_PROTOCOLS
71                | StatusCode::PROCESSING
72                | StatusCode::NO_CONTENT => {
73                    // skip content-length and transfer-encoding headers
74                    // see https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.1
75                    // and https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2
76                    skip_len = true;
77                    length = BodySize::None
78                }
79
80                StatusCode::NOT_MODIFIED => {
81                    // 304 responses should never have a body but should retain a manually set
82                    // content-length header
83                    // see https://datatracker.ietf.org/doc/html/rfc7232#section-4.1
84                    skip_len = false;
85                    length = BodySize::None;
86                }
87
88                _ => {}
89            }
90        }
91
92        match length {
93            BodySize::Stream => {
94                if chunked {
95                    skip_len = true;
96                    if camel_case {
97                        dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n")
98                    } else {
99                        dst.put_slice(b"\r\ntransfer-encoding: chunked\r\n")
100                    }
101                } else {
102                    skip_len = false;
103                    dst.put_slice(b"\r\n");
104                }
105            }
106            BodySize::Sized(0) if camel_case => dst.put_slice(b"\r\nContent-Length: 0\r\n"),
107            BodySize::Sized(0) => dst.put_slice(b"\r\ncontent-length: 0\r\n"),
108            BodySize::Sized(len) => helpers::write_content_length(len, dst, camel_case),
109            BodySize::None => dst.put_slice(b"\r\n"),
110        }
111
112        // Connection
113        match conn_type {
114            ConnectionType::Upgrade => {
115                if camel_case {
116                    dst.put_slice(b"Connection: Upgrade\r\n")
117                } else {
118                    dst.put_slice(b"connection: upgrade\r\n")
119                }
120            }
121            ConnectionType::KeepAlive if version < Version::HTTP_11 => {
122                if camel_case {
123                    dst.put_slice(b"Connection: keep-alive\r\n")
124                } else {
125                    dst.put_slice(b"connection: keep-alive\r\n")
126                }
127            }
128            ConnectionType::Close if version >= Version::HTTP_11 => {
129                if camel_case {
130                    dst.put_slice(b"Connection: close\r\n")
131                } else {
132                    dst.put_slice(b"connection: close\r\n")
133                }
134            }
135            _ => {}
136        }
137
138        // write headers
139
140        let mut has_date = false;
141
142        let mut buf = dst.chunk_mut().as_mut_ptr();
143        let mut remaining = dst.capacity() - dst.len();
144
145        // tracks bytes written since last buffer resize
146        // since buf is a raw pointer to a bytes container storage but is written to without the
147        // container's knowledge, this is used to sync the containers cursor after data is written
148        let mut pos = 0;
149
150        self.write_headers(|key, value| {
151            match *key {
152                CONNECTION => return,
153                TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
154                DATE => has_date = true,
155                _ => {}
156            }
157
158            let k = key.as_str().as_bytes();
159            let k_len = k.len();
160
161            for val in value.iter() {
162                let v = val.as_ref();
163                let v_len = v.len();
164
165                // key length + value length + colon + space + \r\n
166                let len = k_len + v_len + 4;
167
168                if len > remaining {
169                    // SAFETY: all the bytes written up to position "pos" are initialized
170                    // the written byte count and pointer advancement are kept in sync
171                    unsafe {
172                        dst.advance_mut(pos);
173                    }
174
175                    pos = 0;
176                    dst.reserve(len * 2);
177                    remaining = dst.capacity() - dst.len();
178
179                    // re-assign buf raw pointer since it's possible that the buffer was
180                    // reallocated and/or resized
181                    buf = dst.chunk_mut().as_mut_ptr();
182                }
183
184                // SAFETY: on each write, it is enough to ensure that the advancement of
185                // the cursor matches the number of bytes written
186                unsafe {
187                    if camel_case {
188                        // use Camel-Case headers
189                        write_camel_case(k, buf, k_len);
190                    } else {
191                        write_data(k, buf, k_len);
192                    }
193
194                    buf = buf.add(k_len);
195
196                    write_data(b": ", buf, 2);
197                    buf = buf.add(2);
198
199                    write_data(v, buf, v_len);
200                    buf = buf.add(v_len);
201
202                    write_data(b"\r\n", buf, 2);
203                    buf = buf.add(2);
204                };
205
206                pos += len;
207                remaining -= len;
208            }
209        });
210
211        // final cursor synchronization with the bytes container
212        //
213        // SAFETY: all the bytes written up to position "pos" are initialized
214        // the written byte count and pointer advancement are kept in sync
215        unsafe {
216            dst.advance_mut(pos);
217        }
218
219        if !has_date {
220            // optimized date header, write_date_header writes its own \r\n
221            config.write_date_header(dst, camel_case);
222        }
223
224        // end-of-headers marker
225        dst.extend_from_slice(b"\r\n");
226
227        Ok(())
228    }
229
230    fn write_headers<F>(&mut self, mut f: F)
231    where
232        F: FnMut(&HeaderName, &Value),
233    {
234        match self.extra_headers() {
235            Some(headers) => {
236                // merging headers from head and extra headers.
237                self.headers()
238                    .inner
239                    .iter()
240                    .filter(|(name, _)| !headers.contains_key(*name))
241                    .chain(headers.inner.iter())
242                    .for_each(|(k, v)| f(k, v))
243            }
244            None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)),
245        }
246    }
247}
248
249impl MessageType for Response<()> {
250    fn status(&self) -> Option<StatusCode> {
251        Some(self.head().status)
252    }
253
254    fn chunked(&self) -> bool {
255        self.head().chunked()
256    }
257
258    fn headers(&self) -> &HeaderMap {
259        &self.head().headers
260    }
261
262    fn extra_headers(&self) -> Option<&HeaderMap> {
263        None
264    }
265
266    fn camel_case(&self) -> bool {
267        self.head()
268            .flags
269            .contains(crate::message::Flags::CAMEL_CASE)
270    }
271
272    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
273        let head = self.head();
274        let reason = head.reason().as_bytes();
275        dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
276
277        // status line
278        helpers::write_status_line(head.version, head.status.as_u16(), dst);
279        dst.put_slice(reason);
280        Ok(())
281    }
282}
283
284impl MessageType for RequestHeadType {
285    fn status(&self) -> Option<StatusCode> {
286        None
287    }
288
289    fn chunked(&self) -> bool {
290        self.as_ref().chunked()
291    }
292
293    fn camel_case(&self) -> bool {
294        self.as_ref().camel_case_headers()
295    }
296
297    fn headers(&self) -> &HeaderMap {
298        self.as_ref().headers()
299    }
300
301    fn extra_headers(&self) -> Option<&HeaderMap> {
302        self.extra_headers()
303    }
304
305    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
306        let head = self.as_ref();
307        dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
308        write!(
309            helpers::MutWriter(dst),
310            "{} {} {}",
311            head.method,
312            head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
313            match head.version {
314                Version::HTTP_09 => "HTTP/0.9",
315                Version::HTTP_10 => "HTTP/1.0",
316                Version::HTTP_11 => "HTTP/1.1",
317                Version::HTTP_2 => "HTTP/2.0",
318                Version::HTTP_3 => "HTTP/3.0",
319                _ => return Err(io::Error::other("Unsupported version")),
320            }
321        )
322        .map_err(io::Error::other)
323    }
324}
325
326impl<T: MessageType> MessageEncoder<T> {
327    /// Encode chunk.
328    pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
329        self.te.encode(msg, buf)
330    }
331
332    /// Encode EOF.
333    pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
334        self.te.encode_eof(buf)
335    }
336
337    /// Encode message.
338    pub fn encode(
339        &mut self,
340        dst: &mut BytesMut,
341        message: &mut T,
342        head: bool,
343        stream: bool,
344        version: Version,
345        length: BodySize,
346        conn_type: ConnectionType,
347        config: &ServiceConfig,
348    ) -> io::Result<()> {
349        // transfer encoding
350        if !head {
351            self.te = match length {
352                BodySize::Sized(0) => TransferEncoding::empty(),
353                BodySize::Sized(len) => TransferEncoding::length(len),
354                BodySize::Stream => {
355                    if message.chunked() && !stream {
356                        TransferEncoding::chunked()
357                    } else {
358                        TransferEncoding::eof()
359                    }
360                }
361                BodySize::None => TransferEncoding::empty(),
362            };
363        } else {
364            self.te = TransferEncoding::empty();
365        }
366
367        message.encode_status(dst)?;
368        message.encode_headers(dst, version, length, conn_type, config)
369    }
370}
371
/// Body encoder applying one of the supported HTTP/1 body framings.
#[derive(Debug)]
pub(crate) struct TransferEncoding {
    kind: TransferEncodingKind,
}

/// The body framings a `TransferEncoding` can apply.
#[derive(Debug, PartialEq, Clone)]
enum TransferEncodingKind {
    /// `Transfer-Encoding: chunked` framing.
    ///
    /// The flag records whether the terminating zero-length chunk has already been written.
    Chunked(bool),

    /// Fixed-length framing for a known `Content-Length`; holds the bytes still allowed.
    ///
    /// Input beyond that count is dropped, so the body never exceeds the declared length.
    Length(u64),

    /// Unframed body of unknown length.
    ///
    /// The application decides when to stop writing.
    Eof,
}
393
394impl TransferEncoding {
395    #[inline]
396    pub fn empty() -> TransferEncoding {
397        TransferEncoding {
398            kind: TransferEncodingKind::Length(0),
399        }
400    }
401
402    #[inline]
403    pub fn eof() -> TransferEncoding {
404        TransferEncoding {
405            kind: TransferEncodingKind::Eof,
406        }
407    }
408
409    #[inline]
410    pub fn chunked() -> TransferEncoding {
411        TransferEncoding {
412            kind: TransferEncodingKind::Chunked(false),
413        }
414    }
415
416    #[inline]
417    pub fn length(len: u64) -> TransferEncoding {
418        TransferEncoding {
419            kind: TransferEncodingKind::Length(len),
420        }
421    }
422
423    /// Encode message. Return `EOF` state of encoder
424    #[inline]
425    pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
426        match self.kind {
427            TransferEncodingKind::Eof => {
428                let eof = msg.is_empty();
429                buf.extend_from_slice(msg);
430                Ok(eof)
431            }
432            TransferEncodingKind::Chunked(ref mut eof) => {
433                if *eof {
434                    return Ok(true);
435                }
436
437                if msg.is_empty() {
438                    *eof = true;
439                    buf.extend_from_slice(b"0\r\n\r\n");
440                } else {
441                    writeln!(helpers::MutWriter(buf), "{:X}\r", msg.len())
442                        .map_err(io::Error::other)?;
443
444                    buf.reserve(msg.len() + 2);
445                    buf.extend_from_slice(msg);
446                    buf.extend_from_slice(b"\r\n");
447                }
448                Ok(*eof)
449            }
450            TransferEncodingKind::Length(ref mut remaining) => {
451                if *remaining > 0 {
452                    if msg.is_empty() {
453                        return Ok(*remaining == 0);
454                    }
455                    let len = cmp::min(*remaining, msg.len() as u64);
456
457                    buf.extend_from_slice(&msg[..len as usize]);
458
459                    *remaining -= len;
460                    Ok(*remaining == 0)
461                } else {
462                    Ok(true)
463                }
464            }
465        }
466    }
467
468    /// Encode eof. Return `EOF` state of encoder
469    #[inline]
470    pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
471        match self.kind {
472            TransferEncodingKind::Eof => Ok(()),
473            TransferEncodingKind::Length(rem) => {
474                if rem != 0 {
475                    Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
476                } else {
477                    Ok(())
478                }
479            }
480            TransferEncodingKind::Chunked(ref mut eof) => {
481                if !*eof {
482                    *eof = true;
483                    buf.extend_from_slice(b"0\r\n\r\n");
484                }
485                Ok(())
486            }
487        }
488    }
489}
490
/// Copies `value` into the memory starting at `buf`.
///
/// # Safety
/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is
/// valid for writes of at least `len` bytes (and does not overlap `value`).
unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) {
    debug_assert_eq!(value.len(), len);
    // SAFETY: `value` provides `len` readable bytes and the caller guarantees `buf` is valid for
    // `len` non-overlapping writes.
    unsafe { copy_nonoverlapping(value.as_ptr(), buf, len) };
}

/// Copies the header name `value` into `buf` in Camel-Case: the first byte and the byte that
/// follows each `-` are upper-cased when they are ASCII lowercase letters.
///
/// # Safety
/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is
/// valid for writes of at least `len` bytes.
unsafe fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) {
    // copy the name verbatim first, then fix up the casing in place
    // SAFETY: forwarded from this function's own contract
    unsafe { write_data(value, buf, len) };

    // SAFETY: the first `len` bytes at `buf` were just initialized by `write_data`
    let out = unsafe { from_raw_parts_mut(buf, len) };

    // The byte right after a `-` is consumed by that hyphen and is not itself checked for being
    // a hyphen, so "a--b" becomes "A--b".
    let mut capitalize = true;
    for byte in out.iter_mut() {
        if capitalize {
            byte.make_ascii_uppercase();
            capitalize = false;
        } else if *byte == b'-' {
            capitalize = true;
        }
    }
}
532
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use bytes::Bytes;
    use http::header::{AUTHORIZATION, UPGRADE_INSECURE_REQUESTS};

    use super::*;
    use crate::{
        header::{HeaderValue, CONTENT_TYPE},
        RequestHead,
    };

    /// Encodes the headers of `msg` as HTTP/1.1 with a default config, then drains everything
    /// written to `buf` so far and returns it as a string.
    fn render<T: MessageType>(
        msg: &mut T,
        buf: &mut BytesMut,
        length: BodySize,
        conn_type: ConnectionType,
    ) -> String {
        // the result is deliberately ignored; the assertions inspect the produced bytes
        let _ = msg.encode_headers(
            buf,
            Version::HTTP_11,
            length,
            conn_type,
            &ServiceConfig::default(),
        );
        String::from_utf8(buf.split().to_vec()).unwrap()
    }

    #[test]
    fn test_chunked_te() {
        let mut out = BytesMut::new();
        let mut te = TransferEncoding::chunked();

        let done = te.encode(b"test", &mut out).ok().unwrap();
        assert!(!done);
        let done = te.encode(b"", &mut out).ok().unwrap();
        assert!(done);

        assert_eq!(
            out.split().freeze(),
            Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
        );
    }

    #[actix_rt::test]
    async fn test_camel_case() {
        let mut buf = BytesMut::with_capacity(2048);

        let mut camel = RequestHead::default();
        camel.set_camel_case_headers(true);
        camel.headers.insert(DATE, HeaderValue::from_static("date"));
        camel
            .headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        camel
            .headers
            .insert(UPGRADE_INSECURE_REQUESTS, HeaderValue::from_static("1"));
        let mut camel = RequestHeadType::Owned(camel);

        let out = render(&mut camel, &mut buf, BodySize::Sized(0), ConnectionType::Close);
        for line in [
            "Content-Length: 0\r\n",
            "Connection: close\r\n",
            "Content-Type: plain/text\r\n",
            "Date: date\r\n",
            "Upgrade-Insecure-Requests: 1\r\n",
        ] {
            assert!(out.contains(line), "{line:?} missing from {out:?}");
        }

        let out = render(&mut camel, &mut buf, BodySize::None, ConnectionType::Upgrade);
        assert!(out.contains("Connection: Upgrade\r\n"));

        let out = render(&mut camel, &mut buf, BodySize::Stream, ConnectionType::KeepAlive);
        for line in [
            "Transfer-Encoding: chunked\r\n",
            "Content-Type: plain/text\r\n",
            "Date: date\r\n",
        ] {
            assert!(out.contains(line), "{line:?} missing from {out:?}");
        }

        let mut lower = RequestHead::default();
        lower.set_camel_case_headers(false);
        lower.headers.insert(DATE, HeaderValue::from_static("date"));
        lower
            .headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        lower
            .headers
            .append(CONTENT_TYPE, HeaderValue::from_static("xml"));
        let mut lower = RequestHeadType::Owned(lower);

        let out = render(&mut lower, &mut buf, BodySize::Stream, ConnectionType::KeepAlive);
        for line in [
            "transfer-encoding: chunked\r\n",
            "content-type: xml\r\n",
            "content-type: plain/text\r\n",
            "date: date\r\n",
        ] {
            assert!(out.contains(line), "{line:?} missing from {out:?}");
        }
    }

    #[actix_rt::test]
    async fn test_extra_headers() {
        let mut buf = BytesMut::with_capacity(2048);

        let mut base = RequestHead::default();
        base.headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("some authorization"),
        );

        let mut extra = HeaderMap::new();
        extra.insert(
            AUTHORIZATION,
            HeaderValue::from_static("another authorization"),
        );
        extra.insert(DATE, HeaderValue::from_static("date"));

        let mut req = RequestHeadType::Rc(Rc::new(base), Some(extra));

        let out = render(&mut req, &mut buf, BodySize::Sized(0), ConnectionType::Close);
        for line in [
            "content-length: 0\r\n",
            "connection: close\r\n",
            "authorization: another authorization\r\n",
            "date: date\r\n",
        ] {
            assert!(out.contains(line), "{line:?} missing from {out:?}");
        }
    }

    #[actix_rt::test]
    async fn test_no_content_length() {
        let mut buf = BytesMut::with_capacity(2048);

        let mut res = Response::with_body(StatusCode::SWITCHING_PROTOCOLS, ());
        res.headers_mut().insert(DATE, HeaderValue::from_static(""));
        res.headers_mut()
            .insert(CONTENT_LENGTH, HeaderValue::from_static("0"));

        let out = render(&mut res, &mut buf, BodySize::Stream, ConnectionType::Upgrade);
        assert!(!out.contains("content-length: 0\r\n"));
        assert!(!out.contains("transfer-encoding: chunked\r\n"));
    }
}