1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io, mem, net,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11use actix_codec::{Framed, FramedParts};
12use actix_rt::time::sleep_until;
13use actix_service::Service;
14use bitflags::bitflags;
15use bytes::{Buf, BytesMut};
16use futures_core::ready;
17use pin_project_lite::pin_project;
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::{Decoder as _, Encoder as _};
20use tracing::{error, trace};
21
22use super::{
23 codec::Codec,
24 decoder::MAX_BUFFER_SIZE,
25 payload::{Payload, PayloadSender, PayloadStatus},
26 timer::TimerState,
27 Message, MessageType,
28};
29use crate::{
30 body::{BodySize, BoxBody, MessageBody},
31 config::ServiceConfig,
32 error::{DispatchError, ParseError, PayloadError},
33 service::HttpFlow,
34 Error, Extensions, OnConnectData, Request, Response, StatusCode,
35};
36
/// Low-water mark for the read buffer: when its spare capacity drops below this many bytes,
/// the buffer is grown before the next read.
const LW_BUFFER_SIZE: usize = 1 << 10;

/// Initial capacity of the read and write buffers; also used to size read-buffer growth.
const HW_BUFFER_SIZE: usize = 8 << 10;

/// No further input is decoded while this many messages are already queued for processing.
const MAX_PIPELINED_MESSAGES: usize = 0x10;
40
41bitflags! {
42 #[derive(Debug, Clone, Copy)]
43 pub struct Flags: u8 {
44 const STARTED = 0b0000_0001;
46
47 const FINISHED = 0b0000_0010;
49
50 const KEEP_ALIVE = 0b0000_0100;
52
53 const SHUTDOWN = 0b0000_1000;
55
56 const READ_DISCONNECT = 0b0001_0000;
58
59 const WRITE_DISCONNECT = 0b0010_0000;
61 }
62}
63
#[cfg(not(test))]
pin_project! {
    /// Future driving a single HTTP/1.x connection.
    ///
    /// It either runs the normal request/response dispatcher or, after a protocol upgrade,
    /// the future returned by the upgrade service.
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current phase of the connection (normal dispatching or upgrade)
        #[pin]
        inner: DispatcherState<T, S, B, X, U>,
    }
}
89
#[cfg(test)]
pin_project! {
    /// Test build of the connection future: identical to the non-test version, but its state
    /// is visible to the parent module and it counts how often it has been polled.
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current phase of the connection (normal dispatching or upgrade)
        #[pin]
        pub(super) inner: DispatcherState<T, S, B, X, U>,

        // incremented at the start of every `Future::poll` call
        pub(super) poll_count: u64,
    }
}
113
pin_project! {
    // Phase of the connection: regular HTTP/1 dispatching, or the upgrade service's future
    // once the I/O stream has been handed over to it.
    #[project = DispatcherStateProj]
    pub(super) enum DispatcherState<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // request/response processing on the raw I/O stream
        Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U> },
        // future returned by the upgrade service; owns the framed I/O stream
        Upgrade { #[pin] fut: U::Future },
    }
}
133
pin_project! {
    // State of the regular (non-upgraded) HTTP/1 dispatcher for one connection.
    #[project = InnerDispatcherProj]
    pub(super) struct InnerDispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // main service, expect service and optional upgrade service
        flow: Rc<HttpFlow<S, X, U>>,
        // connection state bits
        pub(super) flags: Flags,
        // copied into the head of every decoded request
        peer_addr: Option<net::SocketAddr>,
        // cloned into every decoded request's `conn_data`
        conn_data: Option<Rc<Extensions>>,
        config: ServiceConfig,
        // error recorded while decoding; returned from `poll` once the dispatcher is idle and
        // the write buffer has been drained
        error: Option<DispatchError>,

        // response state machine for the request currently being served
        #[pin]
        pub(super) state: State<S, B, X>,
        // feeds body data of the request currently being read; `None` when no payload is open
        payload: Option<PayloadSender>,
        // decoded requests / dispatcher-generated error responses waiting to be served
        messages: VecDeque<DispatcherMessage>,

        // armed on first poll if a client request deadline is configured; cleared once the
        // first request head is decoded
        head_timer: TimerState,
        // armed after a finished response when keep-alive is in effect
        ka_timer: TimerState,
        // armed when the keep-alive timer expires and a client disconnect deadline is configured
        shutdown_timer: TimerState,

        // `Option` so the stream can be taken and handed to the upgrade service
        pub(super) io: Option<T>,
        read_buf: BytesMut,
        write_buf: BytesMut,
        codec: Codec,
    }
}
172
/// Work item queued for the response state machine.
enum DispatcherMessage {
    /// Decoded request waiting to be passed to the (expect and) main service.
    Item(Request),
    /// Decoded request that will be handed, together with the I/O stream, to the upgrade service.
    Upgrade(Request),
    /// Response head generated by the dispatcher itself (e.g. 400, 431, 500), sent with an
    /// empty body.
    Error(Response<()>),
}
178
pin_project! {
    // Response state machine for the request currently being served.
    #[project = StateProj]
    pub(super) enum State<S, B, X>
    where
        S: Service<Request>,
        X: Service<Request, Response = Request>,
        B: MessageBody,
    {
        // idle: nothing in progress
        None,
        // waiting on the expect service
        ExpectCall { #[pin] fut: X::Future },
        // waiting on the main service
        ServiceCall { #[pin] fut: S::Future },
        // streaming the body of a service response into the write buffer
        SendPayload { #[pin] body: B },
        // streaming the body of an error response into the write buffer
        SendErrorPayload { #[pin] body: BoxBody },
    }
}
194
195impl<S, B, X> State<S, B, X>
196where
197 S: Service<Request>,
198 X: Service<Request, Response = Request>,
199 B: MessageBody,
200{
201 pub(super) fn is_none(&self) -> bool {
202 matches!(self, State::None)
203 }
204}
205
206impl<S, B, X> fmt::Debug for State<S, B, X>
207where
208 S: Service<Request>,
209 X: Service<Request, Response = Request>,
210 B: MessageBody,
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 match self {
214 Self::None => write!(f, "State::None"),
215 Self::ExpectCall { .. } => f.debug_struct("State::ExpectCall").finish_non_exhaustive(),
216 Self::ServiceCall { .. } => {
217 f.debug_struct("State::ServiceCall").finish_non_exhaustive()
218 }
219 Self::SendPayload { .. } => {
220 f.debug_struct("State::SendPayload").finish_non_exhaustive()
221 }
222 Self::SendErrorPayload { .. } => f
223 .debug_struct("State::SendErrorPayload")
224 .finish_non_exhaustive(),
225 }
226 }
227}
228
/// Outcome of one `InnerDispatcher::poll_response` call.
#[derive(Debug)]
enum PollResponse {
    /// The next queued message is an upgrade request; the connection must be handed over.
    Upgrade(Request),
    /// No further progress is possible right now.
    DoNothing,
    /// The write buffer reached its size limit while streaming a body and should be flushed
    /// before more body data is produced.
    DrainWriteBuf,
}
235
236impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
237where
238 T: AsyncRead + AsyncWrite + Unpin,
239
240 S: Service<Request>,
241 S::Error: Into<Response<BoxBody>>,
242 S::Response: Into<Response<B>>,
243
244 B: MessageBody,
245
246 X: Service<Request, Response = Request>,
247 X::Error: Into<Response<BoxBody>>,
248
249 U: Service<(Request, Framed<T, Codec>), Response = ()>,
250 U::Error: fmt::Display,
251{
252 pub(crate) fn new(
254 io: T,
255 flow: Rc<HttpFlow<S, X, U>>,
256 config: ServiceConfig,
257 peer_addr: Option<net::SocketAddr>,
258 conn_data: OnConnectData,
259 ) -> Self {
260 Dispatcher {
261 inner: DispatcherState::Normal {
262 inner: InnerDispatcher {
263 flow,
264 flags: Flags::empty(),
265 peer_addr,
266 conn_data: conn_data.0.map(Rc::new),
267 config: config.clone(),
268 error: None,
269
270 state: State::None,
271 payload: None,
272 messages: VecDeque::new(),
273
274 head_timer: TimerState::new(config.client_request_deadline().is_some()),
275 ka_timer: TimerState::new(config.keep_alive().enabled()),
276 shutdown_timer: TimerState::new(config.client_disconnect_deadline().is_some()),
277
278 io: Some(io),
279 read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
280 write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
281 codec: Codec::new(config),
282 },
283 },
284
285 #[cfg(test)]
286 poll_count: 0,
287 }
288 }
289}
290
291impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
292where
293 T: AsyncRead + AsyncWrite + Unpin,
294
295 S: Service<Request>,
296 S::Error: Into<Response<BoxBody>>,
297 S::Response: Into<Response<B>>,
298
299 B: MessageBody,
300
301 X: Service<Request, Response = Request>,
302 X::Error: Into<Response<BoxBody>>,
303
304 U: Service<(Request, Framed<T, Codec>), Response = ()>,
305 U::Error: fmt::Display,
306{
307 fn can_read(&self, cx: &mut Context<'_>) -> bool {
308 if self.flags.contains(Flags::READ_DISCONNECT) {
309 false
310 } else if let Some(ref info) = self.payload {
311 info.need_read(cx) == PayloadStatus::Read
312 } else {
313 true
314 }
315 }
316
317 fn client_disconnected(self: Pin<&mut Self>) {
318 let this = self.project();
319
320 this.flags
321 .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
322
323 if let Some(mut payload) = this.payload.take() {
324 payload.set_error(PayloadError::Incomplete(None));
325 }
326 }
327
328 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
329 let InnerDispatcherProj { io, write_buf, .. } = self.project();
330 let mut io = Pin::new(io.as_mut().unwrap());
331
332 let len = write_buf.len();
333 let mut written = 0;
334
335 while written < len {
336 match io.as_mut().poll_write(cx, &write_buf[written..])? {
337 Poll::Ready(0) => {
338 error!("write zero; closing");
339 return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
340 }
341
342 Poll::Ready(n) => written += n,
343
344 Poll::Pending => {
345 write_buf.advance(written);
346 return Poll::Pending;
347 }
348 }
349 }
350
351 write_buf.clear();
353
354 io.poll_flush(cx)
356 }
357
358 fn send_response_inner(
359 self: Pin<&mut Self>,
360 res: Response<()>,
361 body: &impl MessageBody,
362 ) -> Result<BodySize, DispatchError> {
363 let this = self.project();
364
365 let size = body.size();
366
367 this.codec
368 .encode(Message::Item((res, size)), this.write_buf)
369 .map_err(|err| {
370 if let Some(mut payload) = this.payload.take() {
371 payload.set_error(PayloadError::Incomplete(None));
372 }
373
374 DispatchError::Io(err)
375 })?;
376
377 Ok(size)
378 }
379
380 fn send_response(
381 mut self: Pin<&mut Self>,
382 res: Response<()>,
383 body: B,
384 ) -> Result<(), DispatchError> {
385 let size = self.as_mut().send_response_inner(res, &body)?;
386 let mut this = self.project();
387 this.state.set(match size {
388 BodySize::None | BodySize::Sized(0) => {
389 let payload_unfinished = this.payload.is_some();
390
391 if payload_unfinished {
392 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
393 } else {
394 this.flags.insert(Flags::FINISHED);
395 }
396
397 State::None
398 }
399 _ => State::SendPayload { body },
400 });
401
402 Ok(())
403 }
404
405 fn send_error_response(
406 mut self: Pin<&mut Self>,
407 res: Response<()>,
408 body: BoxBody,
409 ) -> Result<(), DispatchError> {
410 let size = self.as_mut().send_response_inner(res, &body)?;
411 let mut this = self.project();
412 this.state.set(match size {
413 BodySize::None | BodySize::Sized(0) => {
414 let payload_unfinished = this.payload.is_some();
415
416 if payload_unfinished {
417 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
418 } else {
419 this.flags.insert(Flags::FINISHED);
420 }
421
422 State::None
423 }
424 _ => State::SendErrorPayload { body },
425 });
426
427 Ok(())
428 }
429
430 fn send_continue(self: Pin<&mut Self>) {
431 self.project()
432 .write_buf
433 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
434 }
435
436 fn poll_response(
437 mut self: Pin<&mut Self>,
438 cx: &mut Context<'_>,
439 ) -> Result<PollResponse, DispatchError> {
440 'res: loop {
441 let mut this = self.as_mut().project();
442 match this.state.as_mut().project() {
443 StateProj::None => match this.messages.pop_front() {
445 Some(DispatcherMessage::Item(req)) => {
447 if req.head().expect() {
449 let fut = this.flow.expect.call(req);
451 this.state.set(State::ExpectCall { fut });
452 } else {
453 let fut = this.flow.service.call(req);
455 this.state.set(State::ServiceCall { fut });
456 };
457 }
458
459 Some(DispatcherMessage::Error(res)) => {
461 self.as_mut().send_error_response(res, BoxBody::new(()))?;
465 }
466
467 Some(DispatcherMessage::Upgrade(req)) => return Ok(PollResponse::Upgrade(req)),
469
470 None => {
472 this.flags.set(Flags::KEEP_ALIVE, this.codec.keep_alive());
474
475 return Ok(PollResponse::DoNothing);
476 }
477 },
478
479 StateProj::ServiceCall { fut } => {
480 match fut.poll(cx) {
481 Poll::Ready(Ok(res)) => {
483 let (res, body) = res.into().replace_body(());
484 self.as_mut().send_response(res, body)?;
485 }
486
487 Poll::Ready(Err(err)) => {
489 let res: Response<BoxBody> = err.into();
490 let (res, body) = res.replace_body(());
491 self.as_mut().send_error_response(res, body)?;
492 }
493
494 Poll::Pending => {
497 if !self.as_mut().poll_request(cx)? {
500 return Ok(PollResponse::DoNothing);
501 }
502 }
504 }
505 }
506
507 StateProj::SendPayload { mut body } => {
508 while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
511 match body.as_mut().poll_next(cx) {
512 Poll::Ready(Some(Ok(item))) => {
513 this.codec
514 .encode(Message::Chunk(Some(item)), this.write_buf)?;
515 }
516
517 Poll::Ready(None) => {
518 this.codec.encode(Message::Chunk(None), this.write_buf)?;
519
520 let payload_unfinished = this.payload.is_some();
525 let not_pipelined = this.messages.is_empty();
526
527 this.state.set(State::None);
530
531 if not_pipelined && payload_unfinished {
532 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
533 } else {
534 this.flags.insert(Flags::FINISHED);
535 }
536
537 continue 'res;
538 }
539
540 Poll::Ready(Some(Err(err))) => {
541 let err = err.into();
542 tracing::error!("Response payload stream error: {err:?}");
543 this.flags.insert(Flags::FINISHED);
544 return Err(DispatchError::Body(err));
545 }
546
547 Poll::Pending => return Ok(PollResponse::DoNothing),
548 }
549 }
550
551 return Ok(PollResponse::DrainWriteBuf);
554 }
555
556 StateProj::SendErrorPayload { mut body } => {
557 while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
562 match body.as_mut().poll_next(cx) {
563 Poll::Ready(Some(Ok(item))) => {
564 this.codec
565 .encode(Message::Chunk(Some(item)), this.write_buf)?;
566 }
567
568 Poll::Ready(None) => {
569 this.codec.encode(Message::Chunk(None), this.write_buf)?;
570
571 let payload_unfinished = this.payload.is_some();
576 let not_pipelined = this.messages.is_empty();
577
578 this.state.set(State::None);
581
582 if not_pipelined && payload_unfinished {
583 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
584 } else {
585 this.flags.insert(Flags::FINISHED);
586 }
587
588 continue 'res;
589 }
590
591 Poll::Ready(Some(Err(err))) => {
592 tracing::error!("Response payload stream error: {err:?}");
593 this.flags.insert(Flags::FINISHED);
594 return Err(DispatchError::Body(
595 Error::new_body().with_cause(err).into(),
596 ));
597 }
598
599 Poll::Pending => return Ok(PollResponse::DoNothing),
600 }
601 }
602
603 return Ok(PollResponse::DrainWriteBuf);
606 }
607
608 StateProj::ExpectCall { fut } => {
609 trace!(" calling expect service");
610
611 match fut.poll(cx) {
612 Poll::Ready(Ok(req)) => {
615 this.write_buf
616 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
617 let fut = this.flow.service.call(req);
618 this.state.set(State::ServiceCall { fut });
619 }
620
621 Poll::Ready(Err(err)) => {
623 let res: Response<BoxBody> = err.into();
624 let (res, body) = res.replace_body(());
625 self.as_mut().send_error_response(res, body)?;
626 }
627
628 Poll::Pending => return Ok(PollResponse::DoNothing),
630 }
631 }
632 }
633 }
634 }
635
636 fn handle_request(
637 mut self: Pin<&mut Self>,
638 req: Request,
639 cx: &mut Context<'_>,
640 ) -> Result<(), DispatchError> {
641 {
643 let mut this = self.as_mut().project();
644
645 if req.head().expect() {
647 let fut = this.flow.expect.call(req);
649 this.state.set(State::ExpectCall { fut });
650 } else {
651 let fut = this.flow.service.call(req);
653 this.state.set(State::ServiceCall { fut });
654 };
655 };
656
657 loop {
659 match self.as_mut().project().state.project() {
660 StateProj::ExpectCall { fut } => {
661 match fut.poll(cx) {
662 Poll::Ready(Ok(req)) => {
664 self.as_mut().send_continue();
665
666 let mut this = self.as_mut().project();
667 let fut = this.flow.service.call(req);
668 this.state.set(State::ServiceCall { fut });
669
670 continue;
671 }
672
673 Poll::Ready(Err(err)) => {
677 let res: Response<BoxBody> = err.into();
678 let (res, body) = res.replace_body(());
679 return self.send_error_response(res, body);
680 }
681
682 Poll::Pending => return Ok(()),
685 }
686 }
687
688 StateProj::ServiceCall { fut } => {
689 return match fut.poll(cx) {
691 Poll::Ready(Ok(res)) => {
695 let (res, body) = res.into().replace_body(());
696 self.as_mut().send_response(res, body)
697 }
698
699 Poll::Pending => Ok(()),
701
702 Poll::Ready(Err(err)) => {
704 let res: Response<BoxBody> = err.into();
705 let (res, body) = res.replace_body(());
706 self.as_mut().send_error_response(res, body)
707 }
708 };
709 }
710
711 _ => {
712 unreachable!("State must be set to ServiceCall or ExceptCall in handle_request")
713 }
714 }
715 }
716 }
717
718 fn poll_request(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
722 let pipeline_queue_full = self.messages.len() >= MAX_PIPELINED_MESSAGES;
723 let can_not_read = !self.can_read(cx);
724
725 if pipeline_queue_full || can_not_read {
727 return Ok(false);
728 }
729
730 let mut this = self.as_mut().project();
731
732 let mut updated = false;
733
734 loop {
736 match this.codec.decode(this.read_buf) {
737 Ok(Some(msg)) => {
738 updated = true;
739
740 match msg {
741 Message::Item(mut req) => {
742 this.head_timer.clear(line!());
744
745 req.head_mut().peer_addr = *this.peer_addr;
746
747 req.conn_data.clone_from(this.conn_data);
748
749 match this.codec.message_type() {
750 MessageType::None => {}
752
753 MessageType::Stream if this.flow.upgrade.is_some() => {
757 this.messages.push_back(DispatcherMessage::Upgrade(req));
758 break;
759 }
760
761 MessageType::Payload | MessageType::Stream => {
763 let (sender, payload) = Payload::create(false);
769 *req.payload() = crate::Payload::H1 { payload };
770 *this.payload = Some(sender);
771 }
772 }
773
774 if this.state.is_none() {
776 self.as_mut().handle_request(req, cx)?;
777 this = self.as_mut().project();
778 } else {
779 this.messages.push_back(DispatcherMessage::Item(req));
780 }
781 }
782
783 Message::Chunk(Some(chunk)) => {
784 if let Some(ref mut payload) = this.payload {
785 payload.feed_data(chunk);
786 } else {
787 error!("Internal server error: unexpected payload chunk");
788 this.flags.insert(Flags::READ_DISCONNECT);
789 this.messages.push_back(DispatcherMessage::Error(
790 Response::internal_server_error().drop_body(),
791 ));
792 *this.error = Some(DispatchError::InternalError);
793 break;
794 }
795 }
796
797 Message::Chunk(None) => {
798 if let Some(mut payload) = this.payload.take() {
799 payload.feed_eof();
800 } else {
801 error!("Internal server error: unexpected eof");
802 this.flags.insert(Flags::READ_DISCONNECT);
803 this.messages.push_back(DispatcherMessage::Error(
804 Response::internal_server_error().drop_body(),
805 ));
806 *this.error = Some(DispatchError::InternalError);
807 break;
808 }
809 }
810 }
811 }
812
813 Ok(None) => break,
816
817 Err(ParseError::Io(err)) => {
818 trace!("I/O error: {}", &err);
819 self.as_mut().client_disconnected();
820 this = self.as_mut().project();
821 *this.error = Some(DispatchError::Io(err));
822 break;
823 }
824
825 Err(ParseError::TooLarge) => {
826 trace!("request head was too big; returning 431 response");
827
828 if let Some(mut payload) = this.payload.take() {
829 payload.set_error(PayloadError::Overflow);
830 }
831
832 this.messages
834 .push_back(DispatcherMessage::Error(Response::with_body(
835 StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
836 (),
837 )));
838
839 this.flags.insert(Flags::READ_DISCONNECT);
840 *this.error = Some(ParseError::TooLarge.into());
841
842 break;
843 }
844
845 Err(err) => {
846 trace!("parse error {}", &err);
847
848 if let Some(mut payload) = this.payload.take() {
849 payload.set_error(PayloadError::EncodingCorrupted);
850 }
851
852 this.messages.push_back(DispatcherMessage::Error(
854 Response::bad_request().drop_body(),
855 ));
856
857 this.flags.insert(Flags::READ_DISCONNECT);
858 *this.error = Some(err.into());
859 break;
860 }
861 }
862 }
863
864 Ok(updated)
865 }
866
867 fn poll_head_timer(
868 mut self: Pin<&mut Self>,
869 cx: &mut Context<'_>,
870 ) -> Result<(), DispatchError> {
871 let this = self.as_mut().project();
872
873 if let TimerState::Active { timer } = this.head_timer {
874 if timer.as_mut().poll(cx).is_ready() {
875 trace!("timed out on slow request; replying with 408 and closing connection");
878
879 let _ = self.as_mut().send_error_response(
880 Response::with_body(StatusCode::REQUEST_TIMEOUT, ()),
881 BoxBody::new(()),
882 );
883
884 self.project().flags.insert(Flags::SHUTDOWN);
885 }
886 };
887
888 Ok(())
889 }
890
891 fn poll_ka_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
892 let this = self.as_mut().project();
893 if let TimerState::Active { timer } = this.ka_timer {
894 debug_assert!(
895 this.flags.contains(Flags::KEEP_ALIVE),
896 "keep-alive flag should be set when timer is active",
897 );
898 debug_assert!(
899 this.state.is_none(),
900 "dispatcher should not be in keep-alive phase if state is not none: {:?}",
901 this.state,
902 );
903
904 if timer.as_mut().poll(cx).is_ready() {
916 trace!("timer timed out; closing connection");
918 this.flags.insert(Flags::SHUTDOWN);
919
920 if let Some(deadline) = this.config.client_disconnect_deadline() {
921 this.shutdown_timer
923 .set_and_init(cx, sleep_until(deadline.into()), line!());
924 } else {
925 this.flags.insert(Flags::WRITE_DISCONNECT);
927 }
928 }
929 }
930
931 Ok(())
932 }
933
934 fn poll_shutdown_timer(
935 mut self: Pin<&mut Self>,
936 cx: &mut Context<'_>,
937 ) -> Result<(), DispatchError> {
938 let this = self.as_mut().project();
939 if let TimerState::Active { timer } = this.shutdown_timer {
940 debug_assert!(
941 this.flags.contains(Flags::SHUTDOWN),
942 "shutdown flag should be set when timer is active",
943 );
944
945 if timer.as_mut().poll(cx).is_ready() {
947 trace!("timed-out during shutdown");
948 return Err(DispatchError::DisconnectTimeout);
949 }
950 }
951
952 Ok(())
953 }
954
955 fn poll_timers(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
957 self.as_mut().poll_head_timer(cx)?;
958 self.as_mut().poll_ka_timer(cx)?;
959 self.as_mut().poll_shutdown_timer(cx)?;
960
961 Ok(())
962 }
963
964 #[inline(always)] fn read_available(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
971 let this = self.project();
972
973 if this.flags.contains(Flags::READ_DISCONNECT) {
974 return Ok(false);
975 };
976
977 let mut io = Pin::new(this.io.as_mut().unwrap());
978
979 let mut read_some = false;
980
981 loop {
982 if this.read_buf.len() >= MAX_BUFFER_SIZE {
984 match this.payload {
1003 Some(ref p) if p.need_read(cx) != PayloadStatus::Read => {}
1018 _ => cx.waker().wake_by_ref(),
1019 }
1020
1021 return Ok(false);
1022 }
1023
1024 let remaining = this.read_buf.capacity() - this.read_buf.len();
1026 if remaining < LW_BUFFER_SIZE {
1027 this.read_buf.reserve(HW_BUFFER_SIZE - remaining);
1028 }
1029
1030 match tokio_util::io::poll_read_buf(io.as_mut(), cx, this.read_buf) {
1031 Poll::Ready(Ok(n)) => {
1032 this.flags.remove(Flags::FINISHED);
1033
1034 if n == 0 {
1035 return Ok(true);
1036 }
1037
1038 read_some = true;
1039 }
1040
1041 Poll::Pending => {
1042 return Ok(false);
1043 }
1044
1045 Poll::Ready(Err(err)) => {
1046 return match err.kind() {
1047 io::ErrorKind::WouldBlock => Ok(false),
1049
1050 io::ErrorKind::ConnectionReset if read_some => Ok(true),
1052
1053 _ => Err(DispatchError::Io(err)),
1054 };
1055 }
1056 }
1057 }
1058 }
1059
1060 fn upgrade(self: Pin<&mut Self>, req: Request) -> U::Future {
1062 let this = self.project();
1063 let mut parts = FramedParts::with_read_buf(
1064 this.io.take().unwrap(),
1065 mem::take(this.codec),
1066 mem::take(this.read_buf),
1067 );
1068 parts.write_buf = mem::take(this.write_buf);
1069 let framed = Framed::from_parts(parts);
1070 this.flow.upgrade.as_ref().unwrap().call((req, framed))
1071 }
1072}
1073
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    /// `Ok(())` when the connection ends normally (shutdown completed or client gone).
    /// Otherwise the error that terminated it.
    type Output = Result<(), DispatchError>;

    /// Drives the connection one step.
    ///
    /// In the `Upgrade` state, this polls the upgrade service's future. In the `Normal` state,
    /// it first polls the timers. It then either runs the shutdown sequence (flush, then I/O
    /// shutdown) or runs a read → decode → respond → flush cycle.
    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();

        #[cfg(test)]
        {
            *this.poll_count += 1;
        }

        match this.inner.project() {
            DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| {
                error!("Upgrade handler error: {}", err);
                DispatchError::Upgrade
            }),

            DispatcherStateProj::Normal { mut inner } => {
                trace!("start flags: {:?}", &inner.flags);

                trace_timer_states(
                    "start",
                    &inner.head_timer,
                    &inner.ka_timer,
                    &inner.shutdown_timer,
                );

                // timer expiry may set SHUTDOWN / WRITE_DISCONNECT or fail with a timeout
                inner.as_mut().poll_timers(cx)?;

                let poll = if inner.flags.contains(Flags::SHUTDOWN) {
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        // write half unusable: finish without flushing or shutting down I/O
                        Poll::Ready(Ok(()))
                    } else {
                        // flush pending output first; returns Pending while the flush blocks
                        ready!(inner.as_mut().poll_flush(cx))?;
                        Pin::new(inner.as_mut().project().io.as_mut().unwrap())
                            .poll_shutdown(cx)
                            .map_err(DispatchError::from)
                    }
                } else {
                    // true when the peer closed the stream or reset it after a partial read
                    let should_disconnect = inner.as_mut().read_available(cx)?;

                    // new bytes on an idle keep-alive connection: leave keep-alive, disarm timer
                    if !inner.read_buf.is_empty() && inner.flags.contains(Flags::KEEP_ALIVE) {
                        let inner = inner.as_mut().project();
                        inner.flags.remove(Flags::KEEP_ALIVE);
                        inner.ka_timer.clear(line!());
                    }

                    // first poll of the connection: arm the head timer if configured
                    if !inner.flags.contains(Flags::STARTED) {
                        inner.as_mut().project().flags.insert(Flags::STARTED);

                        if let Some(deadline) = inner.config.client_request_deadline() {
                            inner.as_mut().project().head_timer.set_and_init(
                                cx,
                                sleep_until(deadline.into()),
                                line!(),
                            );
                        }
                    }

                    // decode whatever arrived
                    inner.as_mut().poll_request(cx)?;

                    // peer is done sending: stop reading and end any open request payload
                    if should_disconnect {
                        let inner = inner.as_mut().project();
                        inner.flags.insert(Flags::READ_DISCONNECT);
                        if let Some(mut payload) = inner.payload.take() {
                            payload.feed_eof();
                        }
                    };

                    // Fill the write buffer and flush it. Repeat only while the response
                    // asked for a drain and the flush completed.
                    loop {
                        let drain = match inner.as_mut().poll_response(cx)? {
                            PollResponse::DrainWriteBuf => true,

                            PollResponse::DoNothing => {
                                // response finished on a keep-alive connection: (re)arm the
                                // keep-alive timer if a deadline is configured
                                if inner.flags.contains(Flags::KEEP_ALIVE | Flags::FINISHED) {
                                    if let Some(timer) = inner.config.keep_alive_deadline() {
                                        inner.as_mut().project().ka_timer.set_and_init(
                                            cx,
                                            sleep_until(timer.into()),
                                            line!(),
                                        );
                                    }
                                }

                                false
                            }

                            // hand the I/O over and continue in the Upgrade state
                            PollResponse::Upgrade(req) => {
                                let upgrade = inner.upgrade(req);
                                self.as_mut()
                                    .project()
                                    .inner
                                    .set(DispatcherState::Upgrade { fut: upgrade });
                                return self.poll(cx);
                            }
                        };

                        let flush_was_ready = inner.as_mut().poll_flush(cx)?.is_ready();

                        if !flush_was_ready || !drain {
                            break;
                        }
                    }

                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        trace!("client is gone; disconnecting");
                        return Poll::Ready(Ok(()));
                    }

                    let inner_p = inner.as_mut().project();
                    let state_is_none = inner_p.state.is_none();

                    // Read half closed: shut down unless half-closed connections are allowed
                    // and a response is still in progress.
                    if inner_p.flags.contains(Flags::READ_DISCONNECT)
                        && (!inner_p.config.h1_allow_half_closed() || state_is_none)
                    {
                        trace!("read half closed; start shutdown");
                        inner_p.flags.insert(Flags::SHUTDOWN);
                    }

                    // idle and fully flushed: report stored errors or close the connection
                    if state_is_none && inner_p.write_buf.is_empty() {
                        if let Some(err) = inner_p.error.take() {
                            error!("stream error: {}", &err);
                            return Poll::Ready(Err(err));
                        }

                        // response finished and keep-alive not in effect: shut down
                        if inner_p.flags.contains(Flags::FINISHED)
                            && !inner_p.flags.contains(Flags::KEEP_ALIVE)
                        {
                            inner_p.flags.remove(Flags::FINISHED);
                            inner_p.flags.insert(Flags::SHUTDOWN);
                            return self.poll(cx);
                        }

                        // re-poll to run the shutdown branch immediately
                        if inner_p.flags.contains(Flags::SHUTDOWN) {
                            return self.poll(cx);
                        }
                    }

                    trace_timer_states(
                        "end",
                        inner_p.head_timer,
                        inner_p.ka_timer,
                        inner_p.shutdown_timer,
                    );

                    // shutdown requested but not yet runnable here: schedule another poll
                    if inner_p.flags.contains(Flags::SHUTDOWN) {
                        cx.waker().wake_by_ref();
                    }
                    Poll::Pending
                };

                trace!("end flags: {:?}", &inner.flags);

                poll
            }
        }
    }
}
1278
1279#[allow(dead_code)]
1280fn trace_timer_states(
1281 label: &str,
1282 head_timer: &TimerState,
1283 ka_timer: &TimerState,
1284 shutdown_timer: &TimerState,
1285) {
1286 trace!("{} timers:", label);
1287
1288 if head_timer.is_enabled() {
1289 trace!(" head {}", &head_timer);
1290 }
1291
1292 if ka_timer.is_enabled() {
1293 trace!(" keep-alive {}", &ka_timer);
1294 }
1295
1296 if shutdown_timer.is_enabled() {
1297 trace!(" shutdown {}", &shutdown_timer);
1298 }
1299}