1use std::{
4 future::Future,
5 mem,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use bytes::Bytes;
11use futures_core::ready;
12use pin_project_lite::pin_project;
13
14use crate::{
15 body::EitherBody,
16 dev,
17 web::{Form, Json},
18 Error, FromRequest, HttpRequest, HttpResponse, Responder,
19};
20
/// A value that is exactly one of two alternatives.
///
/// The two variants carry independent type parameters, so the alternatives do not need to share
/// a type. Within this module the type is given [`Responder`] and [`FromRequest`]
/// implementations whenever both `L` and `R` implement the respective trait.
#[derive(Debug, PartialEq, Eq)]
pub enum Either<L, R> {
    /// The first alternative, wrapping a value of type `L`.
    Left(L),

    /// The second alternative, wrapping a value of type `R`.
    Right(R),
}
84
85impl<T> Either<Form<T>, Json<T>> {
86 pub fn into_inner(self) -> T {
87 match self {
88 Either::Left(form) => form.into_inner(),
89 Either::Right(form) => form.into_inner(),
90 }
91 }
92}
93
94impl<T> Either<Json<T>, Form<T>> {
95 pub fn into_inner(self) -> T {
96 match self {
97 Either::Left(form) => form.into_inner(),
98 Either::Right(form) => form.into_inner(),
99 }
100 }
101}
102
/// Test-only helpers for pulling a specific variant out of an [`Either`].
///
/// The methods are private: they are only compiled for test builds and are only needed by code
/// in this module and its children.
#[cfg(test)]
impl<L, R> Either<L, R> {
    /// Returns the value held in the `Left` variant.
    ///
    /// # Panics
    ///
    /// Panics if `self` is the `Right` variant.
    fn unwrap_left(self) -> L {
        match self {
            Either::Left(data) => data,
            Either::Right(_) => {
                panic!("Cannot unwrap Left branch. Either contains an `R` type.")
            }
        }
    }

    /// Returns the value held in the `Right` variant.
    ///
    /// # Panics
    ///
    /// Panics if `self` is the `Left` variant.
    fn unwrap_right(self) -> R {
        match self {
            Either::Right(data) => data,
            Either::Left(_) => {
                panic!("Cannot unwrap Right branch. Either contains an `L` type.")
            }
        }
    }
}
123
124impl<L, R> Responder for Either<L, R>
126where
127 L: Responder,
128 R: Responder,
129{
130 type Body = EitherBody<L::Body, R::Body>;
131
132 fn respond_to(self, req: &HttpRequest) -> HttpResponse<Self::Body> {
133 match self {
134 Either::Left(a) => a.respond_to(req).map_into_left_body(),
135 Either::Right(b) => b.respond_to(req).map_into_right_body(),
136 }
137 }
138}
139
140#[derive(Debug)]
145pub enum EitherExtractError<L, R> {
146 Bytes(Error),
148
149 Extract(L, R),
151}
152
153impl<L, R> From<EitherExtractError<L, R>> for Error
154where
155 L: Into<Error>,
156 R: Into<Error>,
157{
158 fn from(err: EitherExtractError<L, R>) -> Error {
159 match err {
160 EitherExtractError::Bytes(err) => err,
161 EitherExtractError::Extract(a_err, _b_err) => a_err.into(),
162 }
163 }
164}
165
166impl<L, R> FromRequest for Either<L, R>
168where
169 L: FromRequest + 'static,
170 R: FromRequest + 'static,
171{
172 type Error = EitherExtractError<L::Error, R::Error>;
173 type Future = EitherExtractFut<L, R>;
174
175 fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
176 EitherExtractFut {
177 req: req.clone(),
178 state: EitherExtractState::Bytes {
179 bytes: Bytes::from_request(req, payload),
180 },
181 }
182 }
183}
184
185pin_project! {
186 pub struct EitherExtractFut<L, R>
187 where
188 R: FromRequest,
189 L: FromRequest,
190 {
191 req: HttpRequest,
192 #[pin]
193 state: EitherExtractState<L, R>,
194 }
195}
196
197pin_project! {
198 #[project = EitherExtractProj]
199 pub enum EitherExtractState<L, R>
200 where
201 L: FromRequest,
202 R: FromRequest,
203 {
204 Bytes {
205 #[pin]
206 bytes: <Bytes as FromRequest>::Future,
207 },
208 Left {
209 #[pin]
210 left: L::Future,
211 fallback: Bytes,
212 },
213 Right {
214 #[pin]
215 right: R::Future,
216 left_err: Option<L::Error>,
217 },
218 }
219}
220
221impl<R, RF, RE, L, LF, LE> Future for EitherExtractFut<L, R>
222where
223 L: FromRequest<Future = LF, Error = LE>,
224 R: FromRequest<Future = RF, Error = RE>,
225 LF: Future<Output = Result<L, LE>> + 'static,
226 RF: Future<Output = Result<R, RE>> + 'static,
227 LE: Into<Error>,
228 RE: Into<Error>,
229{
230 type Output = Result<Either<L, R>, EitherExtractError<LE, RE>>;
231
232 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233 let mut this = self.project();
234 let ready = loop {
235 let next = match this.state.as_mut().project() {
236 EitherExtractProj::Bytes { bytes } => {
237 let res = ready!(bytes.poll(cx));
238 match res {
239 Ok(bytes) => {
240 let fallback = bytes.clone();
241 let left = L::from_request(this.req, &mut payload_from_bytes(bytes));
242 EitherExtractState::Left { left, fallback }
243 }
244 Err(err) => break Err(EitherExtractError::Bytes(err)),
245 }
246 }
247 EitherExtractProj::Left { left, fallback } => {
248 let res = ready!(left.poll(cx));
249 match res {
250 Ok(extracted) => break Ok(Either::Left(extracted)),
251 Err(left_err) => {
252 let right = R::from_request(
253 this.req,
254 &mut payload_from_bytes(mem::take(fallback)),
255 );
256 EitherExtractState::Right {
257 left_err: Some(left_err),
258 right,
259 }
260 }
261 }
262 }
263 EitherExtractProj::Right { right, left_err } => {
264 let res = ready!(right.poll(cx));
265 match res {
266 Ok(data) => break Ok(Either::Right(data)),
267 Err(err) => {
268 break Err(EitherExtractError::Extract(left_err.take().unwrap(), err));
269 }
270 }
271 }
272 };
273 this.state.set(next);
274 };
275 Poll::Ready(ready)
276 }
277}
278
279fn payload_from_bytes(bytes: Bytes) -> dev::Payload {
280 let (_, mut h1_payload) = actix_http::h1::Payload::create(true);
281 h1_payload.unread_data(bytes);
282 dev::Payload::from(h1_payload)
283}
284
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::test::TestRequest;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestForm {
        hello: String,
    }

    /// Extractor under test: URL-encoded form first, JSON as the fallback.
    type FormOrJson = Either<Form<TestForm>, Json<TestForm>>;

    /// Fixture sent by every test that submits structured data.
    fn hello_world() -> TestForm {
        TestForm {
            hello: "world".to_owned(),
        }
    }

    /// A form body is accepted by the left extractor on the first attempt.
    #[actix_rt::test]
    async fn test_either_extract_first_try() {
        let (req, mut pl) = TestRequest::default().set_form(hello_world()).to_http_parts();

        let extracted = FormOrJson::from_request(&req, &mut pl).await.unwrap();
        let form = extracted.unwrap_left().into_inner();
        assert_eq!(&form.hello, "world");
    }

    /// A JSON body fails the form extractor and is picked up by the right extractor.
    #[actix_rt::test]
    async fn test_either_extract_fallback() {
        let (req, mut pl) = TestRequest::default().set_json(hello_world()).to_http_parts();

        let extracted = FormOrJson::from_request(&req, &mut pl).await.unwrap();
        let form = extracted.unwrap_right().into_inner();
        assert_eq!(&form.hello, "world");
    }

    /// A body that is neither form nor JSON falls through the nested `Either` to raw bytes.
    #[actix_rt::test]
    async fn test_either_extract_recursive_fallback() {
        let (req, mut pl) = TestRequest::default()
            .set_payload(Bytes::from_static(b"!@$%^&*()"))
            .to_http_parts();

        let extracted = Either::<FormOrJson, Bytes>::from_request(&req, &mut pl)
            .await
            .unwrap();
        let payload = extracted.unwrap_right();
        assert_eq!(&payload.as_ref(), &b"!@$%^&*()");
    }

    /// A JSON body is resolved by the inner `Either`'s right side, i.e. outer left, inner right.
    #[actix_rt::test]
    async fn test_either_extract_recursive_fallback_inner() {
        let (req, mut pl) = TestRequest::default().set_json(hello_world()).to_http_parts();

        let extracted = Either::<FormOrJson, Bytes>::from_request(&req, &mut pl)
            .await
            .unwrap();
        let form = extracted.unwrap_left().unwrap_right().into_inner();
        assert_eq!(&form.hello, "world");
    }
}