actix_http/h1/dispatcher.rs

use std::{
    collections::VecDeque,
    fmt,
    future::Future,
    io, mem, net,
    pin::Pin,
    rc::Rc,
    task::{Context, Poll},
};

use actix_codec::{Framed, FramedParts};
use actix_rt::time::sleep_until;
use actix_service::Service;
use bitflags::bitflags;
use bytes::{Buf, BytesMut};
use futures_core::ready;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{Decoder as _, Encoder as _};
use tracing::{error, trace};

use super::{
    codec::Codec,
    decoder::MAX_BUFFER_SIZE,
    payload::{Payload, PayloadSender, PayloadStatus},
    timer::TimerState,
    Message, MessageType,
};
use crate::{
    body::{BodySize, BoxBody, MessageBody},
    config::ServiceConfig,
    error::{DispatchError, ParseError, PayloadError},
    service::HttpFlow,
    Error, Extensions, OnConnectData, Request, Response, StatusCode,
};

/// Low-watermark of spare read-buffer capacity; when remaining capacity drops below this, more is reserved.
const LW_BUFFER_SIZE: usize = 1024;
/// High-watermark used as the initial capacity of the read/write buffers and as the reservation target.
const HW_BUFFER_SIZE: usize = 1024 * 8;
/// Maximum number of decoded-but-unhandled pipelined requests buffered per connection.
const MAX_PIPELINED_MESSAGES: usize = 16;

bitflags! {
    #[derive(Debug, Clone, Copy)]
    pub struct Flags: u8 {
        /// Set when stream is read for the first time.
        const STARTED          = 0b0000_0001;

        /// Set when a full request-response cycle has occurred.
        const FINISHED         = 0b0000_0010;

        /// Set if connection is in keep-alive (inactive) state.
        const KEEP_ALIVE       = 0b0000_0100;

        /// Set if in shutdown procedure.
        const SHUTDOWN         = 0b0000_1000;

        /// Set if read-half is disconnected.
        const READ_DISCONNECT  = 0b0001_0000;

        /// Set if write-half is disconnected.
        const WRITE_DISCONNECT = 0b0010_0000;
    }
}
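// Illustrative sketch only (not used by the dispatcher itself): `bitflags` types combine with
// bitwise operators, which is how compound states are set and tested throughout this file:
//
//     let mut flags = Flags::STARTED | Flags::KEEP_ALIVE;
//     assert!(flags.contains(Flags::KEEP_ALIVE));
//     flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
//     flags.remove(Flags::KEEP_ALIVE);
//     assert!(flags.contains(Flags::SHUTDOWN | Flags::FINISHED));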

// there are two versions of Dispatcher state because of:
// https://github.com/taiki-e/pin-project-lite/issues/3
//
// tl;dr: pin-project-lite doesn't play well with other attribute macros

#[cfg(not(test))]
pin_project! {
    /// Dispatcher for HTTP/1.1 protocol
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        #[pin]
        inner: DispatcherState<T, S, B, X, U>,
    }
}

#[cfg(test)]
pin_project! {
    /// Dispatcher for HTTP/1.1 protocol
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        #[pin]
        pub(super) inner: DispatcherState<T, S, B, X, U>,

        // used in tests
        pub(super) poll_count: u64,
    }
}

pin_project! {
    #[project = DispatcherStateProj]
    pub(super) enum DispatcherState<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U> },
        Upgrade { #[pin] fut: U::Future },
    }
}
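// The dispatcher runs in the `Normal` variant for regular HTTP/1 traffic. When an upgradable
// request is handed off (see `PollResponse::Upgrade` below), the inner dispatcher is consumed
// and the state switches to the `Upgrade` variant, which simply drives the upgrade service's
// future (e.g. a WebSocket handler) to completion.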

pin_project! {
    #[project = InnerDispatcherProj]
    pub(super) struct InnerDispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        flow: Rc<HttpFlow<S, X, U>>,
        pub(super) flags: Flags,
        peer_addr: Option<net::SocketAddr>,
        conn_data: Option<Rc<Extensions>>,
        config: ServiceConfig,
        error: Option<DispatchError>,

        #[pin]
        pub(super) state: State<S, B, X>,
        // when `Some(_)`, the dispatcher is in the process of receiving the request payload
        payload: Option<PayloadSender>,
        messages: VecDeque<DispatcherMessage>,

        head_timer: TimerState,
        ka_timer: TimerState,
        shutdown_timer: TimerState,

        pub(super) io: Option<T>,
        read_buf: BytesMut,
        write_buf: BytesMut,
        codec: Codec,
    }
}

enum DispatcherMessage {
    Item(Request),
    Upgrade(Request),
    Error(Response<()>),
}

pin_project! {
    #[project = StateProj]
    pub(super) enum State<S, B, X>
    where
        S: Service<Request>,
        X: Service<Request, Response = Request>,
        B: MessageBody,
    {
        None,
        ExpectCall { #[pin] fut: X::Future },
        ServiceCall { #[pin] fut: S::Future },
        SendPayload { #[pin] body: B },
        SendErrorPayload { #[pin] body: BoxBody },
    }
}
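// Rough sketch of how `State` typically progresses for a single request (informal; the exact
// path depends on whether the request carries an `Expect: 100-continue` header and whether the
// response has a body):
//
//     None -> [ExpectCall ->] ServiceCall -> SendPayload | SendErrorPayload -> None
//
// `poll_response` drives these transitions, setting the state back to `None` once the response
// body has been fully encoded into the write buffer.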

impl<S, B, X> State<S, B, X>
where
    S: Service<Request>,
    X: Service<Request, Response = Request>,
    B: MessageBody,
{
    pub(super) fn is_none(&self) -> bool {
        matches!(self, State::None)
    }
}

impl<S, B, X> fmt::Debug for State<S, B, X>
where
    S: Service<Request>,
    X: Service<Request, Response = Request>,
    B: MessageBody,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::None => write!(f, "State::None"),
            Self::ExpectCall { .. } => f.debug_struct("State::ExpectCall").finish_non_exhaustive(),
            Self::ServiceCall { .. } => {
                f.debug_struct("State::ServiceCall").finish_non_exhaustive()
            }
            Self::SendPayload { .. } => {
                f.debug_struct("State::SendPayload").finish_non_exhaustive()
            }
            Self::SendErrorPayload { .. } => f
                .debug_struct("State::SendErrorPayload")
                .finish_non_exhaustive(),
        }
    }
}

#[derive(Debug)]
enum PollResponse {
    /// Hand the connection off to the upgrade service together with this request.
    Upgrade(Request),
    /// Nothing more can be written right now.
    DoNothing,
    /// The write buffer passed its size limit and should be drained to the I/O stream before
    /// response polling continues.
    DrainWriteBuf,
}

impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    /// Create HTTP/1 dispatcher.
    pub(crate) fn new(
        io: T,
        flow: Rc<HttpFlow<S, X, U>>,
        config: ServiceConfig,
        peer_addr: Option<net::SocketAddr>,
        conn_data: OnConnectData,
    ) -> Self {
        Dispatcher {
            inner: DispatcherState::Normal {
                inner: InnerDispatcher {
                    flow,
                    flags: Flags::empty(),
                    peer_addr,
                    conn_data: conn_data.0.map(Rc::new),
                    config: config.clone(),
                    error: None,

                    state: State::None,
                    payload: None,
                    messages: VecDeque::new(),

                    head_timer: TimerState::new(config.client_request_deadline().is_some()),
                    ka_timer: TimerState::new(config.keep_alive().enabled()),
                    shutdown_timer: TimerState::new(config.client_disconnect_deadline().is_some()),

                    io: Some(io),
                    read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
                    write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
                    codec: Codec::new(config),
                },
            },

            #[cfg(test)]
            poll_count: 0,
        }
    }
}
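// Note on construction: `config.clone()` here is expected to be cheap, since `ServiceConfig`
// acts as a shared handle to the service configuration. The I/O object is stored as
// `Option<T>` rather than `T` so that `upgrade()` below can take ownership of it and move it
// into the `Framed` handed to the upgrade service when a protocol upgrade happens.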

impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    fn can_read(&self, cx: &mut Context<'_>) -> bool {
        if self.flags.contains(Flags::READ_DISCONNECT) {
            false
        } else if let Some(ref info) = self.payload {
            info.need_read(cx) == PayloadStatus::Read
        } else {
            true
        }
    }
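    // `can_read` is the dispatcher's read-side backpressure check: `PayloadSender` (held here)
    // and `Payload` (handed to the service inside the `Request`) share state, so while the
    // service has not yet consumed buffered payload chunks, `need_read` reports that reading
    // should pause and `poll_request` stops decoding until the consumer catches up and wakes
    // this task again.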

    fn client_disconnected(self: Pin<&mut Self>) {
        let this = self.project();

        this.flags
            .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);

        if let Some(mut payload) = this.payload.take() {
            payload.set_error(PayloadError::Incomplete(None));
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let InnerDispatcherProj { io, write_buf, .. } = self.project();
        let mut io = Pin::new(io.as_mut().unwrap());

        let len = write_buf.len();
        let mut written = 0;

        while written < len {
            match io.as_mut().poll_write(cx, &write_buf[written..])? {
                Poll::Ready(0) => {
                    error!("write zero; closing");
                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
                }

                Poll::Ready(n) => written += n,

                Poll::Pending => {
                    write_buf.advance(written);
                    return Poll::Pending;
                }
            }
        }

        // everything has been written to the I/O stream; clear buffer
        write_buf.clear();

        // flush the I/O stream and check whether it gets blocked
        io.poll_flush(cx)
    }
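    // Note on partial writes: when `poll_write` returns `Pending` part-way through,
    // `write_buf.advance(written)` drops the bytes that were already accepted so only the
    // unwritten tail is retried on the next poll; a complete write clears the buffer instead.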

    fn send_response_inner(
        self: Pin<&mut Self>,
        res: Response<()>,
        body: &impl MessageBody,
    ) -> Result<BodySize, DispatchError> {
        let this = self.project();

        let size = body.size();

        this.codec
            .encode(Message::Item((res, size)), this.write_buf)
            .map_err(|err| {
                if let Some(mut payload) = this.payload.take() {
                    payload.set_error(PayloadError::Incomplete(None));
                }

                DispatchError::Io(err)
            })?;

        Ok(size)
    }

    fn send_response(
        mut self: Pin<&mut Self>,
        res: Response<()>,
        body: B,
    ) -> Result<(), DispatchError> {
        let size = self.as_mut().send_response_inner(res, &body)?;
        let mut this = self.project();
        this.state.set(match size {
            BodySize::None | BodySize::Sized(0) => {
                let payload_unfinished = this.payload.is_some();

                if payload_unfinished {
                    this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
                } else {
                    this.flags.insert(Flags::FINISHED);
                }

                State::None
            }
            _ => State::SendPayload { body },
        });

        Ok(())
    }

    fn send_error_response(
        mut self: Pin<&mut Self>,
        res: Response<()>,
        body: BoxBody,
    ) -> Result<(), DispatchError> {
        let size = self.as_mut().send_response_inner(res, &body)?;
        let mut this = self.project();
        this.state.set(match size {
            BodySize::None | BodySize::Sized(0) => {
                let payload_unfinished = this.payload.is_some();

                if payload_unfinished {
                    this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
                } else {
                    this.flags.insert(Flags::FINISHED);
                }

                State::None
            }
            _ => State::SendErrorPayload { body },
        });

        Ok(())
    }

    fn send_continue(self: Pin<&mut Self>) {
        self.project()
            .write_buf
            .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
    }

    fn poll_response(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Result<PollResponse, DispatchError> {
        'res: loop {
            let mut this = self.as_mut().project();
            match this.state.as_mut().project() {
                // no future is in InnerDispatcher state; pop next message
                StateProj::None => match this.messages.pop_front() {
                    // handle request message
                    Some(DispatcherMessage::Item(req)) => {
                        // Handle `Expect: 100-Continue` header
                        if req.head().expect() {
                            // set InnerDispatcher state and continue loop to poll it
                            let fut = this.flow.expect.call(req);
                            this.state.set(State::ExpectCall { fut });
                        } else {
                            // set InnerDispatcher state and continue loop to poll it
                            let fut = this.flow.service.call(req);
                            this.state.set(State::ServiceCall { fut });
                        };
                    }

                    // handle error message
                    Some(DispatcherMessage::Error(res)) => {
                        // send_response updates InnerDispatcher state to SendPayload or None
                        // (if response body is empty);
                        // continue loop to poll it
                        self.as_mut().send_error_response(res, BoxBody::new(()))?;
                    }

                    // return with upgrade request and poll it exclusively
                    Some(DispatcherMessage::Upgrade(req)) => return Ok(PollResponse::Upgrade(req)),

                    // all messages are dealt with
                    None => {
                        // start keep-alive if last request allowed it
                        this.flags.set(Flags::KEEP_ALIVE, this.codec.keep_alive());

                        return Ok(PollResponse::DoNothing);
                    }
                },

                StateProj::ServiceCall { fut } => {
                    match fut.poll(cx) {
                        // service call resolved; send response
                        Poll::Ready(Ok(res)) => {
                            let (res, body) = res.into().replace_body(());
                            self.as_mut().send_response(res, body)?;
                        }

                        // send service call error as response
                        Poll::Ready(Err(err)) => {
                            let res: Response<BoxBody> = err.into();
                            let (res, body) = res.replace_body(());
                            self.as_mut().send_error_response(res, body)?;
                        }

                        // service call pending and could be waiting for more chunk messages
                        // (pipeline message limit and/or payload can_read limit)
                        Poll::Pending => {
                            // no new message is decoded and no new payload is fed;
                            // nothing to do except wait for new incoming data from the client
                            if !self.as_mut().poll_request(cx)? {
                                return Ok(PollResponse::DoNothing);
                            }
                            // else loop
                        }
                    }
                }

                StateProj::SendPayload { mut body } => {
                    // keep populating the write buffer until the buffer size limit is hit,
                    // the body gets blocked, or it is finished
                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
                        match body.as_mut().poll_next(cx) {
                            Poll::Ready(Some(Ok(item))) => {
                                this.codec
                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
                            }

                            Poll::Ready(None) => {
                                this.codec.encode(Message::Chunk(None), this.write_buf)?;

                                // if we have not yet pipelined to the next request, then
                                // this.payload was the payload for the request we just finished
                                // responding to. We can check to see if we finished reading it
                                // yet, and if not, shut down the connection.
                                let payload_unfinished = this.payload.is_some();
                                let not_pipelined = this.messages.is_empty();

                                // payload stream finished;
                                // set state to None and handle next message
                                this.state.set(State::None);

                                if not_pipelined && payload_unfinished {
                                    this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
                                } else {
                                    this.flags.insert(Flags::FINISHED);
                                }

                                continue 'res;
                            }

                            Poll::Ready(Some(Err(err))) => {
                                let err = err.into();
                                tracing::error!("Response payload stream error: {err:?}");
                                this.flags.insert(Flags::FINISHED);
                                return Err(DispatchError::Body(err));
                            }

                            Poll::Pending => return Ok(PollResponse::DoNothing),
                        }
                    }

                    // buffer is beyond max size;
                    // return and try to write the whole buffer to the I/O stream
                    return Ok(PollResponse::DrainWriteBuf);
                }

                StateProj::SendErrorPayload { mut body } => {
                    // TODO: de-dupe impl with SendPayload

                    // keep populating the write buffer until the buffer size limit is hit,
                    // the body gets blocked, or it is finished
                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
                        match body.as_mut().poll_next(cx) {
                            Poll::Ready(Some(Ok(item))) => {
                                this.codec
                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
                            }

                            Poll::Ready(None) => {
                                this.codec.encode(Message::Chunk(None), this.write_buf)?;

                                // if we have not yet pipelined to the next request, then
                                // this.payload was the payload for the request we just finished
                                // responding to. We can check to see if we finished reading it
                                // yet, and if not, shut down the connection.
                                let payload_unfinished = this.payload.is_some();
                                let not_pipelined = this.messages.is_empty();

                                // payload stream finished;
                                // set state to None and handle next message
                                this.state.set(State::None);

                                if not_pipelined && payload_unfinished {
                                    this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
                                } else {
                                    this.flags.insert(Flags::FINISHED);
                                }

                                continue 'res;
                            }

                            Poll::Ready(Some(Err(err))) => {
                                tracing::error!("Response payload stream error: {err:?}");
                                this.flags.insert(Flags::FINISHED);
                                return Err(DispatchError::Body(
                                    Error::new_body().with_cause(err).into(),
                                ));
                            }

                            Poll::Pending => return Ok(PollResponse::DoNothing),
                        }
                    }

                    // buffer is beyond max size;
                    // return and try to write the whole buffer to the I/O stream
                    return Ok(PollResponse::DrainWriteBuf);
                }

                StateProj::ExpectCall { fut } => {
                    trace!("  calling expect service");

                    match fut.poll(cx) {
                        // expect resolved; write continue to buffer and set InnerDispatcher state
                        // to service call
                        Poll::Ready(Ok(req)) => {
                            this.write_buf
                                .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
                            let fut = this.flow.service.call(req);
                            this.state.set(State::ServiceCall { fut });
                        }

                        // send expect error as response
                        Poll::Ready(Err(err)) => {
                            let res: Response<BoxBody> = err.into();
                            let (res, body) = res.replace_body(());
                            self.as_mut().send_error_response(res, body)?;
                        }

                        // expect must be resolved before progress can be made
                        Poll::Pending => return Ok(PollResponse::DoNothing),
                    }
                }
            }
        }
    }

    fn handle_request(
        mut self: Pin<&mut Self>,
        req: Request,
        cx: &mut Context<'_>,
    ) -> Result<(), DispatchError> {
        // initialize dispatcher state
        {
            let mut this = self.as_mut().project();

            // Handle `Expect: 100-Continue` header
            if req.head().expect() {
                // set dispatcher state to call expect handler
                let fut = this.flow.expect.call(req);
                this.state.set(State::ExpectCall { fut });
            } else {
                // set dispatcher state to call service handler
                let fut = this.flow.service.call(req);
                this.state.set(State::ServiceCall { fut });
            };
        };

        // eagerly poll the future once (or twice if expect is resolved immediately)
        loop {
            match self.as_mut().project().state.project() {
                StateProj::ExpectCall { fut } => {
                    match fut.poll(cx) {
                        // expect is resolved; continue loop and poll the service call branch
                        Poll::Ready(Ok(req)) => {
                            self.as_mut().send_continue();

                            let mut this = self.as_mut().project();
                            let fut = this.flow.service.call(req);
                            this.state.set(State::ServiceCall { fut });

                            continue;
                        }

                        // future errored; send an error response and return a result. Ok(())
                        // notifies the dispatcher that a new state is set and the outer loop
                        // should continue.
                        Poll::Ready(Err(err)) => {
                            let res: Response<BoxBody> = err.into();
                            let (res, body) = res.replace_body(());
                            return self.send_error_response(res, body);
                        }

                        // future is pending; return Ok(()) to notify that a new state is
                        // set and the outer loop should continue
                        Poll::Pending => return Ok(()),
                    }
                }

                StateProj::ServiceCall { fut } => {
                    // return regardless of the service call future's result
                    return match fut.poll(cx) {
                        // future is resolved; send response and return a result. Ok(()) notifies
                        // the dispatcher that a new state is set and the outer loop should
                        // continue.
                        Poll::Ready(Ok(res)) => {
                            let (res, body) = res.into().replace_body(());
                            self.as_mut().send_response(res, body)
                        }

                        // see the comment on ExpectCall state branch's Pending
                        Poll::Pending => Ok(()),

                        // see the comment on ExpectCall state branch's Ready(Err(_))
                        Poll::Ready(Err(err)) => {
                            let res: Response<BoxBody> = err.into();
                            let (res, body) = res.replace_body(());
                            self.as_mut().send_error_response(res, body)
                        }
                    };
                }

                _ => {
                    unreachable!("State must be set to ServiceCall or ExpectCall in handle_request")
                }
            }
        }
    }

    /// Process one incoming request.
    ///
    /// Returns true if any meaningful work was done.
    fn poll_request(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
        let pipeline_queue_full = self.messages.len() >= MAX_PIPELINED_MESSAGES;
        let can_not_read = !self.can_read(cx);

        // limit amount of non-processed requests
        if pipeline_queue_full || can_not_read {
            return Ok(false);
        }

        let mut this = self.as_mut().project();

        let mut updated = false;

        // decode from read buf as many full requests as possible
        loop {
            match this.codec.decode(this.read_buf) {
                Ok(Some(msg)) => {
                    updated = true;

                    match msg {
                        Message::Item(mut req) => {
                            // head timer only applies to first request on connection
                            this.head_timer.clear(line!());

                            req.head_mut().peer_addr = *this.peer_addr;

                            req.conn_data.clone_from(this.conn_data);

                            match this.codec.message_type() {
                                // request has no payload
                                MessageType::None => {}

                                // Request is upgradable. Add upgrade message and break.
                                // Everything remaining in the read buffer will be handed to the
                                // upgraded Request.
                                MessageType::Stream if this.flow.upgrade.is_some() => {
                                    this.messages.push_back(DispatcherMessage::Upgrade(req));
                                    break;
                                }

                                // request is not upgradable
                                MessageType::Payload | MessageType::Stream => {
                                    // PayloadSender and Payload are smart pointers that share the
                                    // same state. PayloadSender is attached to the dispatcher and
                                    // used to sink new chunked request data into that state.
                                    // Payload is attached to the Request and passed to
                                    // Service::call where the state can be collected and consumed.
                                    let (sender, payload) = Payload::create(false);
                                    *req.payload() = crate::Payload::H1 { payload };
                                    *this.payload = Some(sender);
                                }
                            }

                            // handle request early when there is no future in InnerDispatcher state
                            if this.state.is_none() {
                                self.as_mut().handle_request(req, cx)?;
                                this = self.as_mut().project();
                            } else {
                                this.messages.push_back(DispatcherMessage::Item(req));
                            }
                        }

                        Message::Chunk(Some(chunk)) => {
                            if let Some(ref mut payload) = this.payload {
                                payload.feed_data(chunk);
                            } else {
                                error!("Internal server error: unexpected payload chunk");
                                this.flags.insert(Flags::READ_DISCONNECT);
                                this.messages.push_back(DispatcherMessage::Error(
                                    Response::internal_server_error().drop_body(),
                                ));
                                *this.error = Some(DispatchError::InternalError);
                                break;
                            }
                        }

                        Message::Chunk(None) => {
                            if let Some(mut payload) = this.payload.take() {
                                payload.feed_eof();
                            } else {
                                error!("Internal server error: unexpected eof");
                                this.flags.insert(Flags::READ_DISCONNECT);
                                this.messages.push_back(DispatcherMessage::Error(
                                    Response::internal_server_error().drop_body(),
                                ));
                                *this.error = Some(DispatchError::InternalError);
                                break;
                            }
                        }
                    }
                }

                // decode is partial and buffer is not full yet;
                // break and wait for more reads
                Ok(None) => break,

                Err(ParseError::Io(err)) => {
                    trace!("I/O error: {}", &err);
                    self.as_mut().client_disconnected();
                    this = self.as_mut().project();
                    *this.error = Some(DispatchError::Io(err));
                    break;
                }

                Err(ParseError::TooLarge) => {
                    trace!("request head was too big; returning 431 response");

                    if let Some(mut payload) = this.payload.take() {
                        payload.set_error(PayloadError::Overflow);
                    }

                    // request heads that overflow buffer size return a 431 error
                    this.messages
                        .push_back(DispatcherMessage::Error(Response::with_body(
                            StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
                            (),
                        )));

                    this.flags.insert(Flags::READ_DISCONNECT);
                    *this.error = Some(ParseError::TooLarge.into());

                    break;
                }

                Err(err) => {
                    trace!("parse error {}", &err);

                    if let Some(mut payload) = this.payload.take() {
                        payload.set_error(PayloadError::EncodingCorrupted);
                    }

                    // malformed requests are responded to with a 400 error
                    this.messages.push_back(DispatcherMessage::Error(
                        Response::bad_request().drop_body(),
                    ));

                    this.flags.insert(Flags::READ_DISCONNECT);
                    *this.error = Some(err.into());
                    break;
                }
            }
        }

        Ok(updated)
    }

    fn poll_head_timer(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Result<(), DispatchError> {
        let this = self.as_mut().project();

        if let TimerState::Active { timer } = this.head_timer {
            if timer.as_mut().poll(cx).is_ready() {
                // timed out waiting for the first request (slow request); reply with 408

                trace!("timed out on slow request; replying with 408 and closing connection");

                let _ = self.as_mut().send_error_response(
                    Response::with_body(StatusCode::REQUEST_TIMEOUT, ()),
                    BoxBody::new(()),
                );

                self.project().flags.insert(Flags::SHUTDOWN);
            }
        };

        Ok(())
    }

    fn poll_ka_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
        let this = self.as_mut().project();
        if let TimerState::Active { timer } = this.ka_timer {
            debug_assert!(
                this.flags.contains(Flags::KEEP_ALIVE),
                "keep-alive flag should be set when timer is active",
            );
            debug_assert!(
                this.state.is_none(),
                "dispatcher should not be in keep-alive phase if state is not none: {:?}",
                this.state,
            );

            // Assert removed by @robjtede on account of issue #2655. There are cases where an I/O
            // flush can be pending after entering the keep-alive state, causing the subsequent
            // flush wake-up to panic here. This appears to be a Linux-only problem. Leaving the
            // original code below for posterity because a simple and reliable test could not be
            // found to trigger the behavior.
            // debug_assert!(
            //     this.write_buf.is_empty(),
            //     "dispatcher should not be in keep-alive phase if write_buf is not empty",
            // );

            // keep-alive timer has timed out
            if timer.as_mut().poll(cx).is_ready() {
                // no tasks at hand
                trace!("timer timed out; closing connection");
                this.flags.insert(Flags::SHUTDOWN);

                if let Some(deadline) = this.config.client_disconnect_deadline() {
                    // start shutdown timeout if enabled
                    this.shutdown_timer
                        .set_and_init(cx, sleep_until(deadline.into()), line!());
                } else {
                    // no shutdown timeout; drop socket
                    this.flags.insert(Flags::WRITE_DISCONNECT);
                }
            }
        }

        Ok(())
    }

    fn poll_shutdown_timer(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Result<(), DispatchError> {
        let this = self.as_mut().project();
        if let TimerState::Active { timer } = this.shutdown_timer {
            debug_assert!(
                this.flags.contains(Flags::SHUTDOWN),
                "shutdown flag should be set when timer is active",
            );

            // timed out during shutdown; drop connection
            if timer.as_mut().poll(cx).is_ready() {
                trace!("timed-out during shutdown");
                return Err(DispatchError::DisconnectTimeout);
            }
        }

        Ok(())
    }

    /// Poll head, keep-alive, and disconnect timer.
    fn poll_timers(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
        self.as_mut().poll_head_timer(cx)?;
        self.as_mut().poll_ka_timer(cx)?;
        self.as_mut().poll_shutdown_timer(cx)?;

        Ok(())
    }
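    // Overview of the three timers polled above (all driven by `TimerState` from this module):
    // the head timer bounds how long a client may take to send the first request head (on
    // expiry a 408 is sent and the connection shuts down), the keep-alive timer bounds idle
    // time between requests, and the shutdown timer bounds how long a graceful shutdown may
    // take before the connection is dropped with `DispatchError::DisconnectTimeout`.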

    /// Returns true when the I/O stream can be disconnected after writing to it.
    ///
    /// It covers these conditions:
    /// - `std::io::ErrorKind::ConnectionReset` after a partial read;
    /// - all data has been read.
    #[inline(always)] // TODO: bench this inline
    fn read_available(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
        let this = self.project();

        if this.flags.contains(Flags::READ_DISCONNECT) {
            return Ok(false);
        };

        let mut io = Pin::new(this.io.as_mut().unwrap());

        let mut read_some = false;

        loop {
            // Return early when the read buffer exceeds the decoder's max buffer size.
            if this.read_buf.len() >= MAX_BUFFER_SIZE {
                // At this point it's not known whether the I/O stream is still scheduled to be
                // woken up, so force a wake-up of the dispatcher just in case.
                //
                // Reason:
                // AsyncRead mostly only guarantees a wake-up when poll_read returns
                // Poll::Pending.
                //
                // Case:
                // When read_buf is beyond the max buffer size, the data behind this early return
                // could still be successfully parsed as a new Request. That case would not
                // generate ParseError::TooLarge, and since the I/O stream has not been read to
                // Pending, the dispatcher would be stuck until timeout (keep-alive).
                //
                // Note:
                // This is a perf choice to reduce branching on <Request as MessageType>::decode.
                //
                // A Request head too large to parse is only checked on `httparse::Status::Partial`.

                match this.payload {
                    // When the dispatcher has a payload, the responsibility for wake-ups shifts
                    // to `h1::payload::Payload` unless the payload needs a read, in which case it
                    // might not have access to the waker and could result in the dispatcher
                    // getting stuck until timeout.
                    //
                    // Reason:
                    // Waking ourselves up while a payload is in progress would waste polls
                    // and/or result in over-reading.
                    //
                    // Case:
                    // When the payload is (partially) dropped by the user there is no need to
                    // read anymore. In that case read_buf could remain beyond MAX_BUFFER_SIZE
                    // indefinitely and a self wake-up would busy-poll the dispatcher and waste
                    // resources.
                    Some(ref p) if p.need_read(cx) != PayloadStatus::Read => {}
                    _ => cx.waker().wake_by_ref(),
                }

                return Ok(false);
            }

            // grow buffer if necessary
            let remaining = this.read_buf.capacity() - this.read_buf.len();
            if remaining < LW_BUFFER_SIZE {
                this.read_buf.reserve(HW_BUFFER_SIZE - remaining);
            }
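            // For example (illustrative numbers only): with the default 8 KiB capacity and
            // 7400 bytes already buffered, `remaining` is 792 (< LW_BUFFER_SIZE), so
            // `reserve(HW_BUFFER_SIZE - remaining)` asks for room for at least another
            // 7400 bytes before the next read.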

            match tokio_util::io::poll_read_buf(io.as_mut(), cx, this.read_buf) {
                Poll::Ready(Ok(n)) => {
                    this.flags.remove(Flags::FINISHED);

                    if n == 0 {
                        return Ok(true);
                    }

                    read_some = true;
                }

                Poll::Pending => {
                    return Ok(false);
                }

                Poll::Ready(Err(err)) => {
                    return match err.kind() {
                        // treat a WouldBlock error the same as a Pending return
                        io::ErrorKind::WouldBlock => Ok(false),

                        // connection reset after partial read
                        io::ErrorKind::ConnectionReset if read_some => Ok(true),

                        _ => Err(DispatchError::Io(err)),
                    };
                }
            }
        }
    }

    /// Call the upgrade service with the request.
    fn upgrade(self: Pin<&mut Self>, req: Request) -> U::Future {
        let this = self.project();
        let mut parts = FramedParts::with_read_buf(
            this.io.take().unwrap(),
            mem::take(this.codec),
            mem::take(this.read_buf),
        );
        parts.write_buf = mem::take(this.write_buf);
        let framed = Framed::from_parts(parts);
        this.flow.upgrade.as_ref().unwrap().call((req, framed))
    }
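    // `upgrade` moves everything the dispatcher owns about this connection (the raw I/O object,
    // the codec, and any bytes already buffered in either direction) into a `Framed`, so the
    // upgrade service receives the connection exactly where HTTP/1 processing left off. From
    // then on the dispatcher only polls the returned future (`DispatcherState::Upgrade`).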
}

impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    type Output = Result<(), DispatchError>;

    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();

        #[cfg(test)]
        {
            *this.poll_count += 1;
        }

        match this.inner.project() {
            DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| {
                error!("Upgrade handler error: {}", err);
                DispatchError::Upgrade
            }),

            DispatcherStateProj::Normal { mut inner } => {
                trace!("start flags: {:?}", &inner.flags);

                trace_timer_states(
                    "start",
                    &inner.head_timer,
                    &inner.ka_timer,
                    &inner.shutdown_timer,
                );

                inner.as_mut().poll_timers(cx)?;

                let poll = if inner.flags.contains(Flags::SHUTDOWN) {
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        Poll::Ready(Ok(()))
                    } else {
                        // flush buffer and wait if blocked
                        ready!(inner.as_mut().poll_flush(cx))?;
                        Pin::new(inner.as_mut().project().io.as_mut().unwrap())
                            .poll_shutdown(cx)
                            .map_err(DispatchError::from)
                    }
                } else {
                    // read from I/O stream and fill read buffer
                    let should_disconnect = inner.as_mut().read_available(cx)?;

                    // after reading something from stream, clear keep-alive timer
                    if !inner.read_buf.is_empty() && inner.flags.contains(Flags::KEEP_ALIVE) {
                        let inner = inner.as_mut().project();
                        inner.flags.remove(Flags::KEEP_ALIVE);
                        inner.ka_timer.clear(line!());
                    }

                    if !inner.flags.contains(Flags::STARTED) {
                        inner.as_mut().project().flags.insert(Flags::STARTED);

                        if let Some(deadline) = inner.config.client_request_deadline() {
                            inner.as_mut().project().head_timer.set_and_init(
                                cx,
                                sleep_until(deadline.into()),
                                line!(),
                            );
                        }
                    }

                    inner.as_mut().poll_request(cx)?;

                    if should_disconnect {
                        // I/O stream should be closed
                        let inner = inner.as_mut().project();
                        inner.flags.insert(Flags::READ_DISCONNECT);
                        if let Some(mut payload) = inner.payload.take() {
                            payload.feed_eof();
                        }
                    };

                    loop {
                        // poll response to populate write buffer;
                        // drain indicates whether the write buffer should be emptied before the
                        // next run
                        let drain = match inner.as_mut().poll_response(cx)? {
                            PollResponse::DrainWriteBuf => true,

                            PollResponse::DoNothing => {
                                // KEEP_ALIVE is set in poll_response if the client allows it;
                                // FINISHED is set after writing the last chunk of a response
                                if inner.flags.contains(Flags::KEEP_ALIVE | Flags::FINISHED) {
                                    if let Some(timer) = inner.config.keep_alive_deadline() {
                                        inner.as_mut().project().ka_timer.set_and_init(
                                            cx,
                                            sleep_until(timer.into()),
                                            line!(),
                                        );
                                    }
                                }

                                false
                            }

                            // upgrade request; go to the Upgrade variant of DispatcherState
                            PollResponse::Upgrade(req) => {
                                let upgrade = inner.upgrade(req);
                                self.as_mut()
                                    .project()
                                    .inner
                                    .set(DispatcherState::Upgrade { fut: upgrade });
                                return self.poll(cx);
                            }
                        };

                        // we didn't get WouldBlock from the write operation, so data got written
                        // to the kernel completely (macOS) and we have to write again, otherwise
                        // the response can get stuck
                        //
                        // TODO: want to find a reference for this behavior
                        // see the commit that introduced this: 3872d3ba
                        let flush_was_ready = inner.as_mut().poll_flush(cx)?.is_ready();

                        // this assert seems to always hold, but we're not willing to commit to it
                        // until we understand what Nikolay meant when writing the above comment
                        // debug_assert!(flush_was_ready);

                        if !flush_was_ready || !drain {
                            break;
                        }
                    }

                    // client is gone
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        trace!("client is gone; disconnecting");
                        return Poll::Ready(Ok(()));
                    }

                    let inner_p = inner.as_mut().project();
                    let state_is_none = inner_p.state.is_none();

                    // If the read-half is closed, we start the shutdown procedure if either is
                    // true:
                    //
                    // - state is [`State::None`], which means that we're done with request
                    //   processing, so if the client closed its writer-side it means that it
                    //   won't send more requests.
                    // - The user requested that half-closures not be allowed.
                    if inner_p.flags.contains(Flags::READ_DISCONNECT)
                        && (!inner_p.config.h1_allow_half_closed() || state_is_none)
                    {
                        trace!("read half closed; start shutdown");
                        inner_p.flags.insert(Flags::SHUTDOWN);
                    }

                    // keep-alive and stream errors
                    if state_is_none && inner_p.write_buf.is_empty() {
                        if let Some(err) = inner_p.error.take() {
                            error!("stream error: {}", &err);
                            return Poll::Ready(Err(err));
                        }

                        // disconnect if keep-alive is not enabled
                        if inner_p.flags.contains(Flags::FINISHED)
                            && !inner_p.flags.contains(Flags::KEEP_ALIVE)
                        {
                            inner_p.flags.remove(Flags::FINISHED);
                            inner_p.flags.insert(Flags::SHUTDOWN);
                            return self.poll(cx);
                        }

                        // disconnect if shutdown
                        if inner_p.flags.contains(Flags::SHUTDOWN) {
                            return self.poll(cx);
                        }
                    }

                    trace_timer_states(
                        "end",
                        inner_p.head_timer,
                        inner_p.ka_timer,
                        inner_p.shutdown_timer,
                    );

                    if inner_p.flags.contains(Flags::SHUTDOWN) {
                        cx.waker().wake_by_ref();
                    }
                    Poll::Pending
                };

                trace!("end flags: {:?}", &inner.flags);

                poll
            }
        }
    }
}

#[allow(dead_code)]
fn trace_timer_states(
    label: &str,
    head_timer: &TimerState,
    ka_timer: &TimerState,
    shutdown_timer: &TimerState,
) {
    trace!("{} timers:", label);

    if head_timer.is_enabled() {
        trace!("  head {}", &head_timer);
    }

    if ka_timer.is_enabled() {
        trace!("  keep-alive {}", &ka_timer);
    }

    if shutdown_timer.is_enabled() {
        trace!("  shutdown {}", &shutdown_timer);
    }
}