1use std::{
2 cmp,
3 io::{self, Write as _},
4 marker::PhantomData,
5 ptr::copy_nonoverlapping,
6 slice::from_raw_parts_mut,
7};
8
9use bytes::{BufMut, BytesMut};
10
11use crate::{
12 body::BodySize,
13 header::{
14 map::Value, HeaderMap, HeaderName, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
15 },
16 helpers, ConnectionType, RequestHeadType, Response, ServiceConfig, StatusCode, Version,
17};
18
/// Estimated size in bytes of one encoded header line; used by the `encode_status`
/// implementations to pre-size the output buffer before the head is written.
const AVERAGE_HEADER_SIZE: usize = 30;
20
/// Encodes an HTTP/1 message head and holds the body encoder chosen for that message.
#[derive(Debug)]
pub(crate) struct MessageEncoder<T: MessageType> {
    // NOTE(review): within this file `length` is only initialized (in `Default`) and never
    // read, which is presumably why `dead_code` is allowed; confirm whether anything else
    // in the crate still relies on it.
    #[allow(dead_code)]
    pub length: BodySize,
    /// Body encoder selected by the most recent call to `encode` (empty by default).
    pub te: TransferEncoding,
    _phantom: PhantomData<T>,
}
28
29impl<T: MessageType> Default for MessageEncoder<T> {
30 fn default() -> Self {
31 MessageEncoder {
32 length: BodySize::None,
33 te: TransferEncoding::empty(),
34 _phantom: PhantomData,
35 }
36 }
37}
38
39pub(crate) trait MessageType: Sized {
40 fn status(&self) -> Option<StatusCode>;
41
42 fn headers(&self) -> &HeaderMap;
43
44 fn extra_headers(&self) -> Option<&HeaderMap>;
45
46 fn camel_case(&self) -> bool {
47 false
48 }
49
50 fn chunked(&self) -> bool;
51
52 fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;
53
54 fn encode_headers(
55 &mut self,
56 dst: &mut BytesMut,
57 version: Version,
58 mut length: BodySize,
59 conn_type: ConnectionType,
60 config: &ServiceConfig,
61 ) -> io::Result<()> {
62 let chunked = self.chunked();
63 let mut skip_len = length != BodySize::Stream;
64 let camel_case = self.camel_case();
65
66 if let Some(status) = self.status() {
68 match status {
69 StatusCode::CONTINUE
70 | StatusCode::SWITCHING_PROTOCOLS
71 | StatusCode::PROCESSING
72 | StatusCode::NO_CONTENT => {
73 skip_len = true;
77 length = BodySize::None
78 }
79
80 StatusCode::NOT_MODIFIED => {
81 skip_len = false;
85 length = BodySize::None;
86 }
87
88 _ => {}
89 }
90 }
91
92 match length {
93 BodySize::Stream => {
94 if chunked {
95 skip_len = true;
96 if camel_case {
97 dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n")
98 } else {
99 dst.put_slice(b"\r\ntransfer-encoding: chunked\r\n")
100 }
101 } else {
102 skip_len = false;
103 dst.put_slice(b"\r\n");
104 }
105 }
106 BodySize::Sized(0) if camel_case => dst.put_slice(b"\r\nContent-Length: 0\r\n"),
107 BodySize::Sized(0) => dst.put_slice(b"\r\ncontent-length: 0\r\n"),
108 BodySize::Sized(len) => helpers::write_content_length(len, dst, camel_case),
109 BodySize::None => dst.put_slice(b"\r\n"),
110 }
111
112 match conn_type {
114 ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"),
115 ConnectionType::KeepAlive if version < Version::HTTP_11 => {
116 if camel_case {
117 dst.put_slice(b"Connection: keep-alive\r\n")
118 } else {
119 dst.put_slice(b"connection: keep-alive\r\n")
120 }
121 }
122 ConnectionType::Close if version >= Version::HTTP_11 => {
123 if camel_case {
124 dst.put_slice(b"Connection: close\r\n")
125 } else {
126 dst.put_slice(b"connection: close\r\n")
127 }
128 }
129 _ => {}
130 }
131
132 let mut has_date = false;
135
136 let mut buf = dst.chunk_mut().as_mut_ptr();
137 let mut remaining = dst.capacity() - dst.len();
138
139 let mut pos = 0;
143
144 self.write_headers(|key, value| {
145 match *key {
146 CONNECTION => return,
147 TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
148 DATE => has_date = true,
149 _ => {}
150 }
151
152 let k = key.as_str().as_bytes();
153 let k_len = k.len();
154
155 for val in value.iter() {
156 let v = val.as_ref();
157 let v_len = v.len();
158
159 let len = k_len + v_len + 4;
161
162 if len > remaining {
163 unsafe {
166 dst.advance_mut(pos);
167 }
168
169 pos = 0;
170 dst.reserve(len * 2);
171 remaining = dst.capacity() - dst.len();
172
173 buf = dst.chunk_mut().as_mut_ptr();
176 }
177
178 unsafe {
181 if camel_case {
182 write_camel_case(k, buf, k_len);
184 } else {
185 write_data(k, buf, k_len);
186 }
187
188 buf = buf.add(k_len);
189
190 write_data(b": ", buf, 2);
191 buf = buf.add(2);
192
193 write_data(v, buf, v_len);
194 buf = buf.add(v_len);
195
196 write_data(b"\r\n", buf, 2);
197 buf = buf.add(2);
198 };
199
200 pos += len;
201 remaining -= len;
202 }
203 });
204
205 unsafe {
210 dst.advance_mut(pos);
211 }
212
213 if !has_date {
214 config.write_date_header(dst, camel_case);
216 }
217
218 dst.extend_from_slice(b"\r\n");
220
221 Ok(())
222 }
223
224 fn write_headers<F>(&mut self, mut f: F)
225 where
226 F: FnMut(&HeaderName, &Value),
227 {
228 match self.extra_headers() {
229 Some(headers) => {
230 self.headers()
232 .inner
233 .iter()
234 .filter(|(name, _)| !headers.contains_key(*name))
235 .chain(headers.inner.iter())
236 .for_each(|(k, v)| f(k, v))
237 }
238 None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)),
239 }
240 }
241}
242
243impl MessageType for Response<()> {
244 fn status(&self) -> Option<StatusCode> {
245 Some(self.head().status)
246 }
247
248 fn chunked(&self) -> bool {
249 self.head().chunked()
250 }
251
252 fn headers(&self) -> &HeaderMap {
253 &self.head().headers
254 }
255
256 fn extra_headers(&self) -> Option<&HeaderMap> {
257 None
258 }
259
260 fn camel_case(&self) -> bool {
261 self.head()
262 .flags
263 .contains(crate::message::Flags::CAMEL_CASE)
264 }
265
266 fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
267 let head = self.head();
268 let reason = head.reason().as_bytes();
269 dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
270
271 helpers::write_status_line(head.version, head.status.as_u16(), dst);
273 dst.put_slice(reason);
274 Ok(())
275 }
276}
277
278impl MessageType for RequestHeadType {
279 fn status(&self) -> Option<StatusCode> {
280 None
281 }
282
283 fn chunked(&self) -> bool {
284 self.as_ref().chunked()
285 }
286
287 fn camel_case(&self) -> bool {
288 self.as_ref().camel_case_headers()
289 }
290
291 fn headers(&self) -> &HeaderMap {
292 self.as_ref().headers()
293 }
294
295 fn extra_headers(&self) -> Option<&HeaderMap> {
296 self.extra_headers()
297 }
298
299 fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
300 let head = self.as_ref();
301 dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
302 write!(
303 helpers::MutWriter(dst),
304 "{} {} {}",
305 head.method,
306 head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
307 match head.version {
308 Version::HTTP_09 => "HTTP/0.9",
309 Version::HTTP_10 => "HTTP/1.0",
310 Version::HTTP_11 => "HTTP/1.1",
311 Version::HTTP_2 => "HTTP/2.0",
312 Version::HTTP_3 => "HTTP/3.0",
313 _ => return Err(io::Error::other("Unsupported version")),
314 }
315 )
316 .map_err(io::Error::other)
317 }
318}
319
320impl<T: MessageType> MessageEncoder<T> {
321 pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
323 self.te.encode(msg, buf)
324 }
325
326 pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
328 self.te.encode_eof(buf)
329 }
330
331 pub fn encode(
333 &mut self,
334 dst: &mut BytesMut,
335 message: &mut T,
336 head: bool,
337 stream: bool,
338 version: Version,
339 length: BodySize,
340 conn_type: ConnectionType,
341 config: &ServiceConfig,
342 ) -> io::Result<()> {
343 if !head {
345 self.te = match length {
346 BodySize::Sized(0) => TransferEncoding::empty(),
347 BodySize::Sized(len) => TransferEncoding::length(len),
348 BodySize::Stream => {
349 if message.chunked() && !stream {
350 TransferEncoding::chunked()
351 } else {
352 TransferEncoding::eof()
353 }
354 }
355 BodySize::None => TransferEncoding::empty(),
356 };
357 } else {
358 self.te = TransferEncoding::empty();
359 }
360
361 message.encode_status(dst)?;
362 message.encode_headers(dst, version, length, conn_type, config)
363 }
364}
365
/// Body encoder that frames outgoing body data according to one `TransferEncodingKind`.
#[derive(Debug)]
pub(crate) struct TransferEncoding {
    kind: TransferEncodingKind,
}
371
#[derive(Debug, PartialEq, Clone)]
enum TransferEncodingKind {
    /// Chunked transfer coding. The flag records whether the terminating zero-length
    /// chunk (`0\r\n\r\n`) has already been written.
    Chunked(bool),

    /// Fixed-length body; holds the number of bytes still allowed. Data beyond this count
    /// is dropped by `TransferEncoding::encode`.
    Length(u64),

    /// Body of unknown length, written verbatim; an empty chunk reports completion.
    Eof,
}
387
388impl TransferEncoding {
389 #[inline]
390 pub fn empty() -> TransferEncoding {
391 TransferEncoding {
392 kind: TransferEncodingKind::Length(0),
393 }
394 }
395
396 #[inline]
397 pub fn eof() -> TransferEncoding {
398 TransferEncoding {
399 kind: TransferEncodingKind::Eof,
400 }
401 }
402
403 #[inline]
404 pub fn chunked() -> TransferEncoding {
405 TransferEncoding {
406 kind: TransferEncodingKind::Chunked(false),
407 }
408 }
409
410 #[inline]
411 pub fn length(len: u64) -> TransferEncoding {
412 TransferEncoding {
413 kind: TransferEncodingKind::Length(len),
414 }
415 }
416
417 #[inline]
419 pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
420 match self.kind {
421 TransferEncodingKind::Eof => {
422 let eof = msg.is_empty();
423 buf.extend_from_slice(msg);
424 Ok(eof)
425 }
426 TransferEncodingKind::Chunked(ref mut eof) => {
427 if *eof {
428 return Ok(true);
429 }
430
431 if msg.is_empty() {
432 *eof = true;
433 buf.extend_from_slice(b"0\r\n\r\n");
434 } else {
435 writeln!(helpers::MutWriter(buf), "{:X}\r", msg.len())
436 .map_err(io::Error::other)?;
437
438 buf.reserve(msg.len() + 2);
439 buf.extend_from_slice(msg);
440 buf.extend_from_slice(b"\r\n");
441 }
442 Ok(*eof)
443 }
444 TransferEncodingKind::Length(ref mut remaining) => {
445 if *remaining > 0 {
446 if msg.is_empty() {
447 return Ok(*remaining == 0);
448 }
449 let len = cmp::min(*remaining, msg.len() as u64);
450
451 buf.extend_from_slice(&msg[..len as usize]);
452
453 *remaining -= len;
454 Ok(*remaining == 0)
455 } else {
456 Ok(true)
457 }
458 }
459 }
460 }
461
462 #[inline]
464 pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
465 match self.kind {
466 TransferEncodingKind::Eof => Ok(()),
467 TransferEncodingKind::Length(rem) => {
468 if rem != 0 {
469 Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
470 } else {
471 Ok(())
472 }
473 }
474 TransferEncodingKind::Chunked(ref mut eof) => {
475 if !*eof {
476 *eof = true;
477 buf.extend_from_slice(b"0\r\n\r\n");
478 }
479 Ok(())
480 }
481 }
482 }
483}
484
/// Copies `value` into the raw output pointer `buf`.
///
/// # Safety
///
/// `buf` must be valid for writes of `len` bytes and must not overlap `value`. `len` must
/// equal `value.len()`; this is only checked in debug builds.
unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) {
    debug_assert_eq!(value.len(), len);
    copy_nonoverlapping(value.as_ptr(), buf, len);
}

/// Writes the header name `value` to `buf` in Camel-Case.
///
/// The first byte, and every byte directly following a hyphen, is converted to ASCII
/// uppercase, so `content-type` becomes `Content-Type`. A byte that follows a hyphen is
/// never itself treated as a trigger, so `a--b` becomes `A--b`.
///
/// # Safety
///
/// Same contract as `write_data`: `buf` must be valid for writes of `len` bytes and must
/// not overlap `value`, and `len` must equal `value.len()`.
unsafe fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) {
    // copy the name verbatim first, then patch the letters that need uppercasing
    write_data(value, buf, len);

    // SAFETY: `buf` is valid for `len` bytes (caller contract) and was just initialized
    let out = from_raw_parts_mut(buf, len);

    // true while positioned on the first byte of a word (start of name or after a hyphen)
    let mut at_word_start = true;

    for (slot, &byte) in out.iter_mut().zip(value) {
        if at_word_start {
            if byte.is_ascii_lowercase() {
                *slot = byte.to_ascii_uppercase();
            }
            // the word-start byte is consumed even when it is itself a hyphen
            at_word_start = false;
        } else {
            at_word_start = byte == b'-';
        }
    }
}
526
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use bytes::Bytes;
    use http::header::{AUTHORIZATION, UPGRADE_INSECURE_REQUESTS};

    use super::*;
    use crate::{
        header::{HeaderValue, CONTENT_TYPE},
        RequestHead,
    };

    /// Runs `write_camel_case` into a fresh buffer sized for `input`.
    fn camel_case_bytes(input: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; input.len()];
        // SAFETY: `out` provides exactly `input.len()` writable bytes and does not overlap
        // `input`
        unsafe { write_camel_case(input, out.as_mut_ptr(), input.len()) };
        out
    }

    #[test]
    fn test_chunked_te() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::chunked();
        {
            assert!(!enc.encode(b"test", &mut bytes).ok().unwrap());
            assert!(enc.encode(b"", &mut bytes).ok().unwrap());
        }
        assert_eq!(
            bytes.split().freeze(),
            Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
        );
    }

    #[test]
    fn test_chunked_te_terminator_written_once() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::chunked();
        assert!(!enc.encode(b"ab", &mut bytes).unwrap());
        enc.encode_eof(&mut bytes).unwrap();
        enc.encode_eof(&mut bytes).unwrap();
        // data after the terminator is ignored
        assert!(enc.encode(b"ignored", &mut bytes).unwrap());
        assert_eq!(
            bytes.split().freeze(),
            Bytes::from_static(b"2\r\nab\r\n0\r\n\r\n")
        );
    }

    #[test]
    fn test_length_te() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::length(4);
        assert!(!enc.encode(b"te", &mut bytes).unwrap());
        assert!(!enc.encode(b"", &mut bytes).unwrap());
        // input beyond the declared length is dropped
        assert!(enc.encode(b"stxx", &mut bytes).unwrap());
        assert!(enc.encode(b"more", &mut bytes).unwrap());
        assert!(enc.encode_eof(&mut bytes).is_ok());
        assert_eq!(bytes.split().freeze(), Bytes::from_static(b"test"));

        let mut enc = TransferEncoding::length(10);
        assert!(!enc.encode(b"abc", &mut bytes).unwrap());
        let err = enc.encode_eof(&mut bytes).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_eof_te() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::eof();
        assert!(!enc.encode(b"abc", &mut bytes).unwrap());
        assert!(enc.encode(b"", &mut bytes).unwrap());
        assert!(enc.encode_eof(&mut bytes).is_ok());
        assert_eq!(bytes.split().freeze(), Bytes::from_static(b"abc"));
    }

    #[test]
    fn test_write_camel_case() {
        assert_eq!(camel_case_bytes(b"content-type"), b"Content-Type".to_vec());
        assert_eq!(camel_case_bytes(b"www-authenticate"), b"Www-Authenticate".to_vec());
        assert_eq!(camel_case_bytes(b"a--b"), b"A--b".to_vec());
        assert_eq!(camel_case_bytes(b"x-"), b"X-".to_vec());
        assert_eq!(camel_case_bytes(b""), Vec::<u8>::new());
    }

    #[actix_rt::test]
    async fn test_camel_case() {
        let mut bytes = BytesMut::with_capacity(2048);
        let mut head = RequestHead::default();
        head.set_camel_case_headers(true);
        head.headers.insert(DATE, HeaderValue::from_static("date"));
        head.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));

        head.headers
            .insert(UPGRADE_INSECURE_REQUESTS, HeaderValue::from_static("1"));

        let mut head = RequestHeadType::Owned(head);

        head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Sized(0),
            ConnectionType::Close,
            &ServiceConfig::default(),
        )
        .unwrap();
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();

        assert!(data.contains("Content-Length: 0\r\n"));
        assert!(data.contains("Connection: close\r\n"));
        assert!(data.contains("Content-Type: plain/text\r\n"));
        assert!(data.contains("Date: date\r\n"));
        assert!(data.contains("Upgrade-Insecure-Requests: 1\r\n"));

        head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::KeepAlive,
            &ServiceConfig::default(),
        )
        .unwrap();
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(data.contains("Transfer-Encoding: chunked\r\n"));
        assert!(data.contains("Content-Type: plain/text\r\n"));
        assert!(data.contains("Date: date\r\n"));

        let mut head = RequestHead::default();
        head.set_camel_case_headers(false);
        head.headers.insert(DATE, HeaderValue::from_static("date"));
        head.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        head.headers
            .append(CONTENT_TYPE, HeaderValue::from_static("xml"));

        let mut head = RequestHeadType::Owned(head);
        head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::KeepAlive,
            &ServiceConfig::default(),
        )
        .unwrap();
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(data.contains("transfer-encoding: chunked\r\n"));
        assert!(data.contains("content-type: xml\r\n"));
        assert!(data.contains("content-type: plain/text\r\n"));
        assert!(data.contains("date: date\r\n"));
    }

    #[actix_rt::test]
    async fn test_extra_headers() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut head = RequestHead::default();
        head.headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("some authorization"),
        );

        let mut extra_headers = HeaderMap::new();
        extra_headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("another authorization"),
        );
        extra_headers.insert(DATE, HeaderValue::from_static("date"));

        let mut head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers));

        head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Sized(0),
            ConnectionType::Close,
            &ServiceConfig::default(),
        )
        .unwrap();
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(data.contains("content-length: 0\r\n"));
        assert!(data.contains("connection: close\r\n"));
        assert!(data.contains("authorization: another authorization\r\n"));
        assert!(data.contains("date: date\r\n"));
    }

    #[actix_rt::test]
    async fn test_no_content_length() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut res = Response::with_body(StatusCode::SWITCHING_PROTOCOLS, ());
        res.headers_mut().insert(DATE, HeaderValue::from_static(""));
        res.headers_mut()
            .insert(CONTENT_LENGTH, HeaderValue::from_static("0"));

        res.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::Upgrade,
            &ServiceConfig::default(),
        )
        .unwrap();
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(!data.contains("content-length: 0\r\n"));
        assert!(!data.contains("transfer-encoding: chunked\r\n"));
    }
}