// actix_http/h1/encoder.rs

1use std::{
2    cmp,
3    io::{self, Write as _},
4    marker::PhantomData,
5    ptr::copy_nonoverlapping,
6    slice::from_raw_parts_mut,
7};
8
9use bytes::{BufMut, BytesMut};
10
11use crate::{
12    body::BodySize,
13    header::{
14        map::Value, HeaderMap, HeaderName, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
15    },
16    helpers, ConnectionType, RequestHeadType, Response, ServiceConfig, StatusCode, Version,
17};
18
19const AVERAGE_HEADER_SIZE: usize = 30;
20
21#[derive(Debug)]
22pub(crate) struct MessageEncoder<T: MessageType> {
23    #[allow(dead_code)]
24    pub length: BodySize,
25    pub te: TransferEncoding,
26    _phantom: PhantomData<T>,
27}
28
29impl<T: MessageType> Default for MessageEncoder<T> {
30    fn default() -> Self {
31        MessageEncoder {
32            length: BodySize::None,
33            te: TransferEncoding::empty(),
34            _phantom: PhantomData,
35        }
36    }
37}
38
39pub(crate) trait MessageType: Sized {
40    fn status(&self) -> Option<StatusCode>;
41
42    fn headers(&self) -> &HeaderMap;
43
44    fn extra_headers(&self) -> Option<&HeaderMap>;
45
46    fn camel_case(&self) -> bool {
47        false
48    }
49
50    fn chunked(&self) -> bool;
51
52    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;
53
54    fn encode_headers(
55        &mut self,
56        dst: &mut BytesMut,
57        version: Version,
58        mut length: BodySize,
59        conn_type: ConnectionType,
60        config: &ServiceConfig,
61    ) -> io::Result<()> {
62        let chunked = self.chunked();
63        let mut skip_len = length != BodySize::Stream;
64        let camel_case = self.camel_case();
65
66        // Content length
67        if let Some(status) = self.status() {
68            match status {
69                StatusCode::CONTINUE
70                | StatusCode::SWITCHING_PROTOCOLS
71                | StatusCode::PROCESSING
72                | StatusCode::NO_CONTENT => {
73                    // skip content-length and transfer-encoding headers
74                    // see https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.1
75                    // and https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2
76                    skip_len = true;
77                    length = BodySize::None
78                }
79
80                StatusCode::NOT_MODIFIED => {
81                    // 304 responses should never have a body but should retain a manually set
82                    // content-length header
83                    // see https://datatracker.ietf.org/doc/html/rfc7232#section-4.1
84                    skip_len = false;
85                    length = BodySize::None;
86                }
87
88                _ => {}
89            }
90        }
91
92        match length {
93            BodySize::Stream => {
94                if chunked {
95                    skip_len = true;
96                    if camel_case {
97                        dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n")
98                    } else {
99                        dst.put_slice(b"\r\ntransfer-encoding: chunked\r\n")
100                    }
101                } else {
102                    skip_len = false;
103                    dst.put_slice(b"\r\n");
104                }
105            }
106            BodySize::Sized(0) if camel_case => dst.put_slice(b"\r\nContent-Length: 0\r\n"),
107            BodySize::Sized(0) => dst.put_slice(b"\r\ncontent-length: 0\r\n"),
108            BodySize::Sized(len) => helpers::write_content_length(len, dst, camel_case),
109            BodySize::None => dst.put_slice(b"\r\n"),
110        }
111
112        // Connection
113        match conn_type {
114            ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"),
115            ConnectionType::KeepAlive if version < Version::HTTP_11 => {
116                if camel_case {
117                    dst.put_slice(b"Connection: keep-alive\r\n")
118                } else {
119                    dst.put_slice(b"connection: keep-alive\r\n")
120                }
121            }
122            ConnectionType::Close if version >= Version::HTTP_11 => {
123                if camel_case {
124                    dst.put_slice(b"Connection: close\r\n")
125                } else {
126                    dst.put_slice(b"connection: close\r\n")
127                }
128            }
129            _ => {}
130        }
131
132        // write headers
133
134        let mut has_date = false;
135
136        let mut buf = dst.chunk_mut().as_mut_ptr();
137        let mut remaining = dst.capacity() - dst.len();
138
139        // tracks bytes written since last buffer resize
140        // since buf is a raw pointer to a bytes container storage but is written to without the
141        // container's knowledge, this is used to sync the containers cursor after data is written
142        let mut pos = 0;
143
144        self.write_headers(|key, value| {
145            match *key {
146                CONNECTION => return,
147                TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
148                DATE => has_date = true,
149                _ => {}
150            }
151
152            let k = key.as_str().as_bytes();
153            let k_len = k.len();
154
155            for val in value.iter() {
156                let v = val.as_ref();
157                let v_len = v.len();
158
159                // key length + value length + colon + space + \r\n
160                let len = k_len + v_len + 4;
161
162                if len > remaining {
163                    // SAFETY: all the bytes written up to position "pos" are initialized
164                    // the written byte count and pointer advancement are kept in sync
165                    unsafe {
166                        dst.advance_mut(pos);
167                    }
168
169                    pos = 0;
170                    dst.reserve(len * 2);
171                    remaining = dst.capacity() - dst.len();
172
173                    // re-assign buf raw pointer since it's possible that the buffer was
174                    // reallocated and/or resized
175                    buf = dst.chunk_mut().as_mut_ptr();
176                }
177
178                // SAFETY: on each write, it is enough to ensure that the advancement of
179                // the cursor matches the number of bytes written
180                unsafe {
181                    if camel_case {
182                        // use Camel-Case headers
183                        write_camel_case(k, buf, k_len);
184                    } else {
185                        write_data(k, buf, k_len);
186                    }
187
188                    buf = buf.add(k_len);
189
190                    write_data(b": ", buf, 2);
191                    buf = buf.add(2);
192
193                    write_data(v, buf, v_len);
194                    buf = buf.add(v_len);
195
196                    write_data(b"\r\n", buf, 2);
197                    buf = buf.add(2);
198                };
199
200                pos += len;
201                remaining -= len;
202            }
203        });
204
205        // final cursor synchronization with the bytes container
206        //
207        // SAFETY: all the bytes written up to position "pos" are initialized
208        // the written byte count and pointer advancement are kept in sync
209        unsafe {
210            dst.advance_mut(pos);
211        }
212
213        if !has_date {
214            // optimized date header, write_date_header writes its own \r\n
215            config.write_date_header(dst, camel_case);
216        }
217
218        // end-of-headers marker
219        dst.extend_from_slice(b"\r\n");
220
221        Ok(())
222    }
223
224    fn write_headers<F>(&mut self, mut f: F)
225    where
226        F: FnMut(&HeaderName, &Value),
227    {
228        match self.extra_headers() {
229            Some(headers) => {
230                // merging headers from head and extra headers.
231                self.headers()
232                    .inner
233                    .iter()
234                    .filter(|(name, _)| !headers.contains_key(*name))
235                    .chain(headers.inner.iter())
236                    .for_each(|(k, v)| f(k, v))
237            }
238            None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)),
239        }
240    }
241}
242
243impl MessageType for Response<()> {
244    fn status(&self) -> Option<StatusCode> {
245        Some(self.head().status)
246    }
247
248    fn chunked(&self) -> bool {
249        self.head().chunked()
250    }
251
252    fn headers(&self) -> &HeaderMap {
253        &self.head().headers
254    }
255
256    fn extra_headers(&self) -> Option<&HeaderMap> {
257        None
258    }
259
260    fn camel_case(&self) -> bool {
261        self.head()
262            .flags
263            .contains(crate::message::Flags::CAMEL_CASE)
264    }
265
266    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
267        let head = self.head();
268        let reason = head.reason().as_bytes();
269        dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
270
271        // status line
272        helpers::write_status_line(head.version, head.status.as_u16(), dst);
273        dst.put_slice(reason);
274        Ok(())
275    }
276}
277
278impl MessageType for RequestHeadType {
279    fn status(&self) -> Option<StatusCode> {
280        None
281    }
282
283    fn chunked(&self) -> bool {
284        self.as_ref().chunked()
285    }
286
287    fn camel_case(&self) -> bool {
288        self.as_ref().camel_case_headers()
289    }
290
291    fn headers(&self) -> &HeaderMap {
292        self.as_ref().headers()
293    }
294
295    fn extra_headers(&self) -> Option<&HeaderMap> {
296        self.extra_headers()
297    }
298
299    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
300        let head = self.as_ref();
301        dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
302        write!(
303            helpers::MutWriter(dst),
304            "{} {} {}",
305            head.method,
306            head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
307            match head.version {
308                Version::HTTP_09 => "HTTP/0.9",
309                Version::HTTP_10 => "HTTP/1.0",
310                Version::HTTP_11 => "HTTP/1.1",
311                Version::HTTP_2 => "HTTP/2.0",
312                Version::HTTP_3 => "HTTP/3.0",
313                _ => return Err(io::Error::other("Unsupported version")),
314            }
315        )
316        .map_err(io::Error::other)
317    }
318}
319
320impl<T: MessageType> MessageEncoder<T> {
321    /// Encode chunk.
322    pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
323        self.te.encode(msg, buf)
324    }
325
326    /// Encode EOF.
327    pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
328        self.te.encode_eof(buf)
329    }
330
331    /// Encode message.
332    pub fn encode(
333        &mut self,
334        dst: &mut BytesMut,
335        message: &mut T,
336        head: bool,
337        stream: bool,
338        version: Version,
339        length: BodySize,
340        conn_type: ConnectionType,
341        config: &ServiceConfig,
342    ) -> io::Result<()> {
343        // transfer encoding
344        if !head {
345            self.te = match length {
346                BodySize::Sized(0) => TransferEncoding::empty(),
347                BodySize::Sized(len) => TransferEncoding::length(len),
348                BodySize::Stream => {
349                    if message.chunked() && !stream {
350                        TransferEncoding::chunked()
351                    } else {
352                        TransferEncoding::eof()
353                    }
354                }
355                BodySize::None => TransferEncoding::empty(),
356            };
357        } else {
358            self.te = TransferEncoding::empty();
359        }
360
361        message.encode_status(dst)?;
362        message.encode_headers(dst, version, length, conn_type, config)
363    }
364}
365
366/// Encoders to handle different Transfer-Encodings.
367#[derive(Debug)]
368pub(crate) struct TransferEncoding {
369    kind: TransferEncodingKind,
370}
371
372#[derive(Debug, PartialEq, Clone)]
373enum TransferEncodingKind {
374    /// An Encoder for when Transfer-Encoding includes `chunked`.
375    Chunked(bool),
376
377    /// An Encoder for when Content-Length is set.
378    ///
379    /// Enforces that the body is not longer than the Content-Length header.
380    Length(u64),
381
382    /// An Encoder for when Content-Length is not known.
383    ///
384    /// Application decides when to stop writing.
385    Eof,
386}
387
388impl TransferEncoding {
389    #[inline]
390    pub fn empty() -> TransferEncoding {
391        TransferEncoding {
392            kind: TransferEncodingKind::Length(0),
393        }
394    }
395
396    #[inline]
397    pub fn eof() -> TransferEncoding {
398        TransferEncoding {
399            kind: TransferEncodingKind::Eof,
400        }
401    }
402
403    #[inline]
404    pub fn chunked() -> TransferEncoding {
405        TransferEncoding {
406            kind: TransferEncodingKind::Chunked(false),
407        }
408    }
409
410    #[inline]
411    pub fn length(len: u64) -> TransferEncoding {
412        TransferEncoding {
413            kind: TransferEncodingKind::Length(len),
414        }
415    }
416
417    /// Encode message. Return `EOF` state of encoder
418    #[inline]
419    pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
420        match self.kind {
421            TransferEncodingKind::Eof => {
422                let eof = msg.is_empty();
423                buf.extend_from_slice(msg);
424                Ok(eof)
425            }
426            TransferEncodingKind::Chunked(ref mut eof) => {
427                if *eof {
428                    return Ok(true);
429                }
430
431                if msg.is_empty() {
432                    *eof = true;
433                    buf.extend_from_slice(b"0\r\n\r\n");
434                } else {
435                    writeln!(helpers::MutWriter(buf), "{:X}\r", msg.len())
436                        .map_err(io::Error::other)?;
437
438                    buf.reserve(msg.len() + 2);
439                    buf.extend_from_slice(msg);
440                    buf.extend_from_slice(b"\r\n");
441                }
442                Ok(*eof)
443            }
444            TransferEncodingKind::Length(ref mut remaining) => {
445                if *remaining > 0 {
446                    if msg.is_empty() {
447                        return Ok(*remaining == 0);
448                    }
449                    let len = cmp::min(*remaining, msg.len() as u64);
450
451                    buf.extend_from_slice(&msg[..len as usize]);
452
453                    *remaining -= len;
454                    Ok(*remaining == 0)
455                } else {
456                    Ok(true)
457                }
458            }
459        }
460    }
461
462    /// Encode eof. Return `EOF` state of encoder
463    #[inline]
464    pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
465        match self.kind {
466            TransferEncodingKind::Eof => Ok(()),
467            TransferEncodingKind::Length(rem) => {
468                if rem != 0 {
469                    Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
470                } else {
471                    Ok(())
472                }
473            }
474            TransferEncodingKind::Chunked(ref mut eof) => {
475                if !*eof {
476                    *eof = true;
477                    buf.extend_from_slice(b"0\r\n\r\n");
478                }
479                Ok(())
480            }
481        }
482    }
483}
484
/// Copies `value` into the memory at `buf`.
///
/// # Safety
/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is
/// valid for writes of at least `len` bytes.
unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) {
    debug_assert_eq!(value.len(), len);
    // SAFETY: upheld by the caller per the contract above; the destination is assumed not to
    // overlap `value`.
    copy_nonoverlapping(value.as_ptr(), buf, len);
}

/// Copies the header name `value` into `buf`, upper-casing an ASCII lowercase letter at the
/// start of the name and directly after each `-` (`content-type` becomes `Content-Type`).
///
/// The byte following a hyphen is never itself treated as a separator, so `a--b` is written
/// as `A--b`. A leading hyphen is not treated as a separator either (`-a` stays `-a`).
///
/// # Safety
/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is
/// valid for writes of at least `len` bytes.
unsafe fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) {
    // start from a verbatim copy, then fix up the letters that must be upper-cased
    write_data(value, buf, len);

    // SAFETY: `buf` is valid for `len` bytes and was fully initialized by `write_data` above
    let out = from_raw_parts_mut(buf, len);

    if value.first().is_some_and(u8::is_ascii_lowercase) {
        out[0].make_ascii_uppercase();
    }

    let mut pos = 1;
    while pos < value.len() {
        if value[pos] == b'-' {
            let next = pos + 1;
            if value.get(next).is_some_and(u8::is_ascii_lowercase) {
                out[next].make_ascii_uppercase();
            }
            // skip the byte after the hyphen: it is never treated as a separator itself
            pos += 2;
        } else {
            pos += 1;
        }
    }
}

#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use bytes::Bytes;
    use http::header::{AUTHORIZATION, UPGRADE_INSECURE_REQUESTS};

    use super::*;
    use crate::{
        header::{HeaderValue, CONTENT_TYPE},
        RequestHead,
    };

    /// Takes everything written into `bytes` so far and decodes it as UTF-8.
    fn take_text(bytes: &mut BytesMut) -> String {
        String::from_utf8(bytes.split().to_vec()).unwrap()
    }

    /// Asserts that every string in `needles` occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(haystack.contains(needle), "missing {needle:?} in {haystack:?}");
        }
    }

    #[test]
    fn test_chunked_te() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::chunked();

        // a non-empty chunk does not finish the body; an empty one writes the terminator
        assert!(!enc.encode(b"test", &mut bytes).unwrap());
        assert!(enc.encode(b"", &mut bytes).unwrap());

        let expected = Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n");
        assert_eq!(bytes.split().freeze(), expected);
    }

    #[actix_rt::test]
    async fn test_camel_case() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut req = RequestHead::default();
        req.set_camel_case_headers(true);
        req.headers.insert(DATE, HeaderValue::from_static("date"));
        req.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        req.headers
            .insert(UPGRADE_INSECURE_REQUESTS, HeaderValue::from_static("1"));
        let mut req = RequestHeadType::Owned(req);

        // empty sized body on a closing connection
        let _ = req.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Sized(0),
            ConnectionType::Close,
            &ServiceConfig::default(),
        );
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "Content-Length: 0\r\n",
                "Connection: close\r\n",
                "Content-Type: plain/text\r\n",
                "Date: date\r\n",
                "Upgrade-Insecure-Requests: 1\r\n",
            ],
        );

        // streaming body on a keep-alive connection
        let _ = req.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::KeepAlive,
            &ServiceConfig::default(),
        );
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "Transfer-Encoding: chunked\r\n",
                "Content-Type: plain/text\r\n",
                "Date: date\r\n",
            ],
        );

        // lowercase names; every value of a multi-valued header is written
        let mut req = RequestHead::default();
        req.set_camel_case_headers(false);
        req.headers.insert(DATE, HeaderValue::from_static("date"));
        req.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        req.headers
            .append(CONTENT_TYPE, HeaderValue::from_static("xml"));
        let mut req = RequestHeadType::Owned(req);

        let _ = req.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::KeepAlive,
            &ServiceConfig::default(),
        );
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "transfer-encoding: chunked\r\n",
                "content-type: xml\r\n",
                "content-type: plain/text\r\n",
                "date: date\r\n",
            ],
        );
    }

    #[actix_rt::test]
    async fn test_extra_headers() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut req = RequestHead::default();
        req.headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("some authorization"),
        );

        // extra headers override same-named head headers
        let mut extra = HeaderMap::new();
        extra.insert(
            AUTHORIZATION,
            HeaderValue::from_static("another authorization"),
        );
        extra.insert(DATE, HeaderValue::from_static("date"));

        let mut req = RequestHeadType::Rc(Rc::new(req), Some(extra));

        let _ = req.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Sized(0),
            ConnectionType::Close,
            &ServiceConfig::default(),
        );
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "content-length: 0\r\n",
                "connection: close\r\n",
                "authorization: another authorization\r\n",
                "date: date\r\n",
            ],
        );
    }

    #[actix_rt::test]
    async fn test_no_content_length() {
        let mut bytes = BytesMut::with_capacity(2048);

        // 101 responses must not carry framing headers, even user-supplied ones
        let mut res = Response::with_body(StatusCode::SWITCHING_PROTOCOLS, ());
        res.headers_mut().insert(DATE, HeaderValue::from_static(""));
        res.headers_mut()
            .insert(CONTENT_LENGTH, HeaderValue::from_static("0"));

        let _ = res.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::Upgrade,
            &ServiceConfig::default(),
        );
        let data = take_text(&mut bytes);
        assert!(!data.contains("content-length: 0\r\n"));
        assert!(!data.contains("transfer-encoding: chunked\r\n"));
    }
}