1use std::{
2 cmp,
3 io::{self, Write as _},
4 marker::PhantomData,
5 ptr::copy_nonoverlapping,
6 slice::from_raw_parts_mut,
7};
8
9use bytes::{BufMut, BytesMut};
10
11use crate::{
12 body::BodySize,
13 header::{
14 map::Value, HeaderMap, HeaderName, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
15 },
16 helpers, ConnectionType, RequestHeadType, Response, ServiceConfig, StatusCode, Version,
17};
18
/// Rough per-header byte estimate used to pre-size the output buffer before a message head
/// is written (multiplied by the header count in the `encode_status` implementations).
const AVERAGE_HEADER_SIZE: usize = 30;
20
21#[derive(Debug)]
22pub(crate) struct MessageEncoder<T: MessageType> {
23 #[allow(dead_code)]
24 pub length: BodySize,
25 pub te: TransferEncoding,
26 _phantom: PhantomData<T>,
27}
28
29impl<T: MessageType> Default for MessageEncoder<T> {
30 fn default() -> Self {
31 MessageEncoder {
32 length: BodySize::None,
33 te: TransferEncoding::empty(),
34 _phantom: PhantomData,
35 }
36 }
37}
38
/// Serialization of an HTTP/1 message head: the start line plus the header block.
///
/// Implemented in this module for `Response<()>` and `RequestHeadType`.
pub(crate) trait MessageType: Sized {
    /// Status code of the message, or `None` when the message has none (requests).
    fn status(&self) -> Option<StatusCode>;

    /// The message's own header map.
    fn headers(&self) -> &HeaderMap;

    /// Optional additional headers. When `Some`, any name present here replaces the
    /// same-named entries of [`headers`](Self::headers) (see `write_headers`).
    fn extra_headers(&self) -> Option<&HeaderMap>;

    /// Whether header names are written in Camel-Case; lowercase by default.
    fn camel_case(&self) -> bool {
        false
    }

    /// Whether a `BodySize::Stream` body is announced with `transfer-encoding: chunked`
    /// (see `encode_headers`).
    fn chunked(&self) -> bool;

    /// Writes the start line (status line or request line) into `dst`.
    ///
    /// The implementations in this module do not write the trailing CRLF. `encode_headers`
    /// starts its output with `\r\n` in the arms visible here.
    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;

    /// Writes everything after the start line into `dst`, in this order:
    ///
    /// 1. the framing header derived from `length` (`content-length` /
    ///    `transfer-encoding`), or nothing;
    /// 2. a `connection` header, only when `conn_type` differs from the version's default;
    /// 3. the message's own headers (`connection` is always dropped, and user
    ///    `content-length`/`transfer-encoding` are dropped when the encoder owns framing);
    /// 4. a date header from `config` if the message supplied none;
    /// 5. the blank line that terminates the head.
    ///
    /// The status code may override `length`. For 100, 101, 102 and 204 responses no length
    /// headers are written at all. For 304 the computed length is dropped, but a user-supplied
    /// length header is kept.
    fn encode_headers(
        &mut self,
        dst: &mut BytesMut,
        version: Version,
        mut length: BodySize,
        conn_type: ConnectionType,
        config: &ServiceConfig,
    ) -> io::Result<()> {
        let chunked = self.chunked();
        // `skip_len == true` means user-supplied `transfer-encoding` / `content-length`
        // headers are filtered out below because the encoder decides framing itself.
        let mut skip_len = length != BodySize::Stream;
        let camel_case = self.camel_case();

        if let Some(status) = self.status() {
            match status {
                StatusCode::CONTINUE
                | StatusCode::SWITCHING_PROTOCOLS
                | StatusCode::PROCESSING
                | StatusCode::NO_CONTENT => {
                    // These responses must not carry length headers at all.
                    skip_len = true;
                    length = BodySize::None
                }

                StatusCode::NOT_MODIFIED => {
                    // No body, but a manually set content-length header is preserved.
                    skip_len = false;
                    length = BodySize::None;
                }

                _ => {}
            }
        }

        // Every arm visible here begins with `\r\n`, which ends the start line written by
        // `encode_status`. The `Sized(len)` arm delegates to `helpers::write_content_length`,
        // which presumably does the same (not visible here).
        match length {
            BodySize::Stream => {
                if chunked {
                    skip_len = true;
                    if camel_case {
                        dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n")
                    } else {
                        dst.put_slice(b"\r\ntransfer-encoding: chunked\r\n")
                    }
                } else {
                    // Unframed stream: let any user-supplied length header through.
                    skip_len = false;
                    dst.put_slice(b"\r\n");
                }
            }
            BodySize::Sized(0) if camel_case => dst.put_slice(b"\r\nContent-Length: 0\r\n"),
            BodySize::Sized(0) => dst.put_slice(b"\r\ncontent-length: 0\r\n"),
            BodySize::Sized(len) => helpers::write_content_length(len, dst, camel_case),
            BodySize::None => dst.put_slice(b"\r\n"),
        }

        // Only emit `connection` when it differs from the protocol default:
        // keep-alive is implicit from HTTP/1.1 on, close is implicit before it.
        match conn_type {
            ConnectionType::Upgrade => {
                if camel_case {
                    dst.put_slice(b"Connection: Upgrade\r\n")
                } else {
                    dst.put_slice(b"connection: upgrade\r\n")
                }
            }
            ConnectionType::KeepAlive if version < Version::HTTP_11 => {
                if camel_case {
                    dst.put_slice(b"Connection: keep-alive\r\n")
                } else {
                    dst.put_slice(b"connection: keep-alive\r\n")
                }
            }
            ConnectionType::Close if version >= Version::HTTP_11 => {
                if camel_case {
                    dst.put_slice(b"Connection: close\r\n")
                } else {
                    dst.put_slice(b"connection: close\r\n")
                }
            }
            _ => {}
        }

        let mut has_date = false;

        // Headers are copied straight into `dst`'s spare capacity through a raw pointer.
        // Invariants kept throughout:
        //   * `buf` points `pos` bytes past the start of `dst`'s spare capacity;
        //   * `remaining` is the spare capacity still unwritten after those `pos` bytes.
        // `dst` only learns about the written bytes when `advance_mut(pos)` is called.
        let mut buf = dst.chunk_mut().as_mut_ptr();
        let mut remaining = dst.capacity() - dst.len();

        // Bytes written through `buf` since the last `advance_mut` sync.
        let mut pos = 0;

        self.write_headers(|key, value| {
            match *key {
                // Derived from `conn_type` above; never copied from the message.
                CONNECTION => return,
                TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
                DATE => has_date = true,
                _ => {}
            }

            let k = key.as_str().as_bytes();
            let k_len = k.len();

            for val in value.iter() {
                let v = val.as_ref();
                let v_len = v.len();

                // key + ": " + value + "\r\n"
                let len = k_len + v_len + 4;

                if len > remaining {
                    // Not enough spare room. Commit what was written, grow the buffer, and
                    // re-derive `buf`, because `reserve` may reallocate and invalidate it.
                    //
                    // SAFETY: exactly `pos` bytes after `dst.len()` were initialized through
                    // `buf` since the last sync, so they may be marked as part of `dst`.
                    unsafe {
                        dst.advance_mut(pos);
                    }

                    pos = 0;
                    dst.reserve(len * 2);
                    remaining = dst.capacity() - dst.len();

                    buf = dst.chunk_mut().as_mut_ptr();
                }

                // SAFETY: here `len <= remaining`, so the `len` bytes written below stay
                // inside `dst`'s spare capacity; `buf` is advanced by exactly the number of
                // bytes written, keeping it in step with `pos`.
                unsafe {
                    if camel_case {
                        write_camel_case(k, buf, k_len);
                    } else {
                        write_data(k, buf, k_len);
                    }

                    buf = buf.add(k_len);

                    write_data(b": ", buf, 2);
                    buf = buf.add(2);

                    write_data(v, buf, v_len);
                    buf = buf.add(v_len);

                    write_data(b"\r\n", buf, 2);
                    buf = buf.add(2);
                };

                pos += len;
                remaining -= len;
            }
        });

        // SAFETY: final sync. The `pos` bytes written through `buf` since the last
        // `advance_mut` are initialized and lie within spare capacity.
        unsafe {
            dst.advance_mut(pos);
        }

        if !has_date {
            config.write_date_header(dst, camel_case);
        }

        // Blank line terminating the message head.
        dst.extend_from_slice(b"\r\n");

        Ok(())
    }

    /// Calls `f` for every `(name, Value)` entry to be written; a `Value` may hold several
    /// header values.
    ///
    /// When `extra_headers()` is `Some`, entries of `headers()` whose name also appears there
    /// are skipped. The extra map's entries are then yielded after the remaining ones.
    fn write_headers<F>(&mut self, mut f: F)
    where
        F: FnMut(&HeaderName, &Value),
    {
        match self.extra_headers() {
            Some(headers) => {
                self.headers()
                    .inner
                    .iter()
                    .filter(|(name, _)| !headers.contains_key(*name))
                    .chain(headers.inner.iter())
                    .for_each(|(k, v)| f(k, v))
            }
            None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)),
        }
    }
}
248
249impl MessageType for Response<()> {
250 fn status(&self) -> Option<StatusCode> {
251 Some(self.head().status)
252 }
253
254 fn chunked(&self) -> bool {
255 self.head().chunked()
256 }
257
258 fn headers(&self) -> &HeaderMap {
259 &self.head().headers
260 }
261
262 fn extra_headers(&self) -> Option<&HeaderMap> {
263 None
264 }
265
266 fn camel_case(&self) -> bool {
267 self.head()
268 .flags
269 .contains(crate::message::Flags::CAMEL_CASE)
270 }
271
272 fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
273 let head = self.head();
274 let reason = head.reason().as_bytes();
275 dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
276
277 helpers::write_status_line(head.version, head.status.as_u16(), dst);
279 dst.put_slice(reason);
280 Ok(())
281 }
282}
283
284impl MessageType for RequestHeadType {
285 fn status(&self) -> Option<StatusCode> {
286 None
287 }
288
289 fn chunked(&self) -> bool {
290 self.as_ref().chunked()
291 }
292
293 fn camel_case(&self) -> bool {
294 self.as_ref().camel_case_headers()
295 }
296
297 fn headers(&self) -> &HeaderMap {
298 self.as_ref().headers()
299 }
300
301 fn extra_headers(&self) -> Option<&HeaderMap> {
302 self.extra_headers()
303 }
304
305 fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
306 let head = self.as_ref();
307 dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
308 write!(
309 helpers::MutWriter(dst),
310 "{} {} {}",
311 head.method,
312 head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
313 match head.version {
314 Version::HTTP_09 => "HTTP/0.9",
315 Version::HTTP_10 => "HTTP/1.0",
316 Version::HTTP_11 => "HTTP/1.1",
317 Version::HTTP_2 => "HTTP/2.0",
318 Version::HTTP_3 => "HTTP/3.0",
319 _ => return Err(io::Error::other("Unsupported version")),
320 }
321 )
322 .map_err(io::Error::other)
323 }
324}
325
326impl<T: MessageType> MessageEncoder<T> {
327 pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
329 self.te.encode(msg, buf)
330 }
331
332 pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
334 self.te.encode_eof(buf)
335 }
336
337 pub fn encode(
339 &mut self,
340 dst: &mut BytesMut,
341 message: &mut T,
342 head: bool,
343 stream: bool,
344 version: Version,
345 length: BodySize,
346 conn_type: ConnectionType,
347 config: &ServiceConfig,
348 ) -> io::Result<()> {
349 if !head {
351 self.te = match length {
352 BodySize::Sized(0) => TransferEncoding::empty(),
353 BodySize::Sized(len) => TransferEncoding::length(len),
354 BodySize::Stream => {
355 if message.chunked() && !stream {
356 TransferEncoding::chunked()
357 } else {
358 TransferEncoding::eof()
359 }
360 }
361 BodySize::None => TransferEncoding::empty(),
362 };
363 } else {
364 self.te = TransferEncoding::empty();
365 }
366
367 message.encode_status(dst)?;
368 message.encode_headers(dst, version, length, conn_type, config)
369 }
370}
371
/// Body framing applied by the encoder to each piece of payload it writes.
#[derive(Debug)]
pub(crate) struct TransferEncoding {
    kind: TransferEncodingKind,
}

/// The concrete framing modes behind [`TransferEncoding`].
#[derive(Clone, Debug, PartialEq)]
enum TransferEncodingKind {
    /// Chunked framing. The flag records whether the terminating zero-length chunk has
    /// already been written.
    Chunked(bool),

    /// Fixed-length body. Holds the number of bytes still accepted; surplus input is dropped.
    Length(u64),

    /// Bytes are copied through without framing. Presumably the body ends when the
    /// connection closes.
    Eof,
}
393
394impl TransferEncoding {
395 #[inline]
396 pub fn empty() -> TransferEncoding {
397 TransferEncoding {
398 kind: TransferEncodingKind::Length(0),
399 }
400 }
401
402 #[inline]
403 pub fn eof() -> TransferEncoding {
404 TransferEncoding {
405 kind: TransferEncodingKind::Eof,
406 }
407 }
408
409 #[inline]
410 pub fn chunked() -> TransferEncoding {
411 TransferEncoding {
412 kind: TransferEncodingKind::Chunked(false),
413 }
414 }
415
416 #[inline]
417 pub fn length(len: u64) -> TransferEncoding {
418 TransferEncoding {
419 kind: TransferEncodingKind::Length(len),
420 }
421 }
422
423 #[inline]
425 pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
426 match self.kind {
427 TransferEncodingKind::Eof => {
428 let eof = msg.is_empty();
429 buf.extend_from_slice(msg);
430 Ok(eof)
431 }
432 TransferEncodingKind::Chunked(ref mut eof) => {
433 if *eof {
434 return Ok(true);
435 }
436
437 if msg.is_empty() {
438 *eof = true;
439 buf.extend_from_slice(b"0\r\n\r\n");
440 } else {
441 writeln!(helpers::MutWriter(buf), "{:X}\r", msg.len())
442 .map_err(io::Error::other)?;
443
444 buf.reserve(msg.len() + 2);
445 buf.extend_from_slice(msg);
446 buf.extend_from_slice(b"\r\n");
447 }
448 Ok(*eof)
449 }
450 TransferEncodingKind::Length(ref mut remaining) => {
451 if *remaining > 0 {
452 if msg.is_empty() {
453 return Ok(*remaining == 0);
454 }
455 let len = cmp::min(*remaining, msg.len() as u64);
456
457 buf.extend_from_slice(&msg[..len as usize]);
458
459 *remaining -= len;
460 Ok(*remaining == 0)
461 } else {
462 Ok(true)
463 }
464 }
465 }
466 }
467
468 #[inline]
470 pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
471 match self.kind {
472 TransferEncodingKind::Eof => Ok(()),
473 TransferEncodingKind::Length(rem) => {
474 if rem != 0 {
475 Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
476 } else {
477 Ok(())
478 }
479 }
480 TransferEncodingKind::Chunked(ref mut eof) => {
481 if !*eof {
482 *eof = true;
483 buf.extend_from_slice(b"0\r\n\r\n");
484 }
485 Ok(())
486 }
487 }
488 }
489}
490
/// Copies the `len` bytes of `value` to the raw destination `buf`.
///
/// # Safety
///
/// `buf` must be valid for writes of `len` bytes and must not overlap `value`. `len` must
/// equal `value.len()`; this is only checked in debug builds.
unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) {
    debug_assert_eq!(value.len(), len);

    let src = value.as_ptr();
    // SAFETY: the caller guarantees `buf` is writable for `len` bytes and disjoint from
    // `value`; `src` is readable for `len` bytes because `value.len() == len`.
    unsafe { copy_nonoverlapping(src, buf, len) }
}
498
499unsafe fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) {
503 write_data(value, buf, len);
505
506 let buffer = from_raw_parts_mut(buf, len);
508
509 let mut iter = value.iter();
510
511 if let Some(c @ b'a'..=b'z') = iter.next() {
513 buffer[0] = c & 0b1101_1111;
514 }
515
516 let mut index = 2;
518
519 while let Some(&c) = iter.next() {
521 if c == b'-' {
522 if let Some(c @ b'a'..=b'z') = iter.next() {
524 buffer[index] = c & 0b1101_1111;
525 }
526 index += 1;
527 }
528
529 index += 1;
530 }
531}
532
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use bytes::Bytes;
    use http::header::{AUTHORIZATION, UPGRADE_INSECURE_REQUESTS};

    use super::*;
    use crate::{
        header::{HeaderValue, CONTENT_TYPE},
        RequestHead,
    };

    #[test]
    fn test_chunked_te() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::chunked();
        {
            assert!(!enc.encode(b"test", &mut bytes).ok().unwrap());
            assert!(enc.encode(b"", &mut bytes).ok().unwrap());
        }
        assert_eq!(
            bytes.split().freeze(),
            Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
        );
    }

    /// `encode_eof` writes the terminator exactly once; later data is ignored.
    #[test]
    fn test_chunked_te_encode_eof_is_idempotent() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::chunked();

        assert!(!enc.encode(b"ab", &mut bytes).unwrap());
        enc.encode_eof(&mut bytes).unwrap();
        enc.encode_eof(&mut bytes).unwrap();
        assert!(enc.encode(b"ignored", &mut bytes).unwrap());

        assert_eq!(
            bytes.split().freeze(),
            Bytes::from_static(b"2\r\nab\r\n0\r\n\r\n")
        );
    }

    /// Fixed-length framing stops at the declared length and drops the surplus.
    #[test]
    fn test_length_te() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::length(4);

        assert!(!enc.encode(b"te", &mut bytes).unwrap());
        assert!(enc.encode(b"stXX", &mut bytes).unwrap());
        assert!(enc.encode(b"more", &mut bytes).unwrap());
        assert!(enc.encode_eof(&mut bytes).is_ok());

        assert_eq!(bytes.split().freeze(), Bytes::from_static(b"test"));
    }

    /// Ending a fixed-length body early is reported as `UnexpectedEof`.
    #[test]
    fn test_length_te_premature_eof() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::length(10);

        assert!(!enc.encode(b"abc", &mut bytes).unwrap());
        assert!(!enc.encode(b"", &mut bytes).unwrap());

        let err = enc.encode_eof(&mut bytes).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// Eof framing passes data through; an empty chunk signals completion.
    #[test]
    fn test_eof_te() {
        let mut bytes = BytesMut::new();
        let mut enc = TransferEncoding::eof();

        assert!(!enc.encode(b"data", &mut bytes).unwrap());
        assert!(enc.encode(b"", &mut bytes).unwrap());
        assert!(enc.encode_eof(&mut bytes).is_ok());

        assert_eq!(bytes.split().freeze(), Bytes::from_static(b"data"));
    }

    #[test]
    fn test_write_camel_case() {
        let name = b"upgrade-insecure-requests";
        let mut out = vec![0u8; name.len()];
        // SAFETY: `out` has exactly `name.len()` writable bytes and does not overlap `name`.
        unsafe { write_camel_case(name, out.as_mut_ptr(), name.len()) };
        assert_eq!(out, b"Upgrade-Insecure-Requests");
    }

    #[actix_rt::test]
    async fn test_camel_case() {
        let mut bytes = BytesMut::with_capacity(2048);
        let mut head = RequestHead::default();
        head.set_camel_case_headers(true);
        head.headers.insert(DATE, HeaderValue::from_static("date"));
        head.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));

        head.headers
            .insert(UPGRADE_INSECURE_REQUESTS, HeaderValue::from_static("1"));

        let mut head = RequestHeadType::Owned(head);

        let _ = head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Sized(0),
            ConnectionType::Close,
            &ServiceConfig::default(),
        );
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();

        assert!(data.contains("Content-Length: 0\r\n"));
        assert!(data.contains("Connection: close\r\n"));
        assert!(data.contains("Content-Type: plain/text\r\n"));
        assert!(data.contains("Date: date\r\n"));
        assert!(data.contains("Upgrade-Insecure-Requests: 1\r\n"));

        let _ = head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::None,
            ConnectionType::Upgrade,
            &ServiceConfig::default(),
        );
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(data.contains("Connection: Upgrade\r\n"));

        let _ = head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::KeepAlive,
            &ServiceConfig::default(),
        );
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(data.contains("Transfer-Encoding: chunked\r\n"));
        assert!(data.contains("Content-Type: plain/text\r\n"));
        assert!(data.contains("Date: date\r\n"));

        let mut head = RequestHead::default();
        head.set_camel_case_headers(false);
        head.headers.insert(DATE, HeaderValue::from_static("date"));
        head.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        head.headers
            .append(CONTENT_TYPE, HeaderValue::from_static("xml"));

        let mut head = RequestHeadType::Owned(head);
        let _ = head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::KeepAlive,
            &ServiceConfig::default(),
        );
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(data.contains("transfer-encoding: chunked\r\n"));
        assert!(data.contains("content-type: xml\r\n"));
        assert!(data.contains("content-type: plain/text\r\n"));
        assert!(data.contains("date: date\r\n"));
    }

    #[actix_rt::test]
    async fn test_extra_headers() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut head = RequestHead::default();
        head.headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("some authorization"),
        );

        let mut extra_headers = HeaderMap::new();
        extra_headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("another authorization"),
        );
        extra_headers.insert(DATE, HeaderValue::from_static("date"));

        let mut head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers));

        let _ = head.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Sized(0),
            ConnectionType::Close,
            &ServiceConfig::default(),
        );
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(data.contains("content-length: 0\r\n"));
        assert!(data.contains("connection: close\r\n"));
        assert!(data.contains("authorization: another authorization\r\n"));
        assert!(data.contains("date: date\r\n"));
    }

    #[actix_rt::test]
    async fn test_no_content_length() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut res = Response::with_body(StatusCode::SWITCHING_PROTOCOLS, ());
        res.headers_mut().insert(DATE, HeaderValue::from_static(""));
        res.headers_mut()
            .insert(CONTENT_LENGTH, HeaderValue::from_static("0"));

        let _ = res.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Stream,
            ConnectionType::Upgrade,
            &ServiceConfig::default(),
        );
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert!(!data.contains("content-length: 0\r\n"));
        assert!(!data.contains("transfer-encoding: chunked\r\n"));
    }

    /// A 304 response drops the computed length but keeps a manually set `content-length`
    /// exactly once.
    #[actix_rt::test]
    async fn test_not_modified_keeps_content_length() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut res = Response::with_body(StatusCode::NOT_MODIFIED, ());
        res.headers_mut()
            .insert(CONTENT_LENGTH, HeaderValue::from_static("10"));

        let _ = res.encode_headers(
            &mut bytes,
            Version::HTTP_11,
            BodySize::Sized(10),
            ConnectionType::KeepAlive,
            &ServiceConfig::default(),
        );
        let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
        assert_eq!(data.matches("content-length: 10\r\n").count(), 1);
    }
}
687}