1use std::{convert::Infallible, str};
6
7use actix_web::{
8 error::ParseError,
9 http::header::{self, Header, HeaderName, HeaderValue, TryIntoHeaderValue},
10 HttpMessage,
11};
12use itertools::Itertools as _;
13
/// The `Forwarded` request header, as defined by RFC 7239.
///
/// Holds the four parameters this type understands: `by`, `for` (which may repeat and so forms
/// a chain), `host` and `proto`. Values are stored exactly as parsed/provided; nothing here
/// validates them.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(test, derive(Default))]
pub struct Forwarded {
    /// Value of the `by` parameter (RFC 7239: the interface where the request entered the proxy).
    by: Option<String>,

    /// Every `for` identifier, kept in the order it was recorded.
    r#for: Vec<String>,

    /// Value of the `host` parameter (RFC 7239: the `Host` header as the proxy received it).
    host: Option<String>,

    /// Value of the `proto` parameter (RFC 7239: the protocol used to make the request).
    proto: Option<String>,
}

impl Forwarded {
    /// Builds a header from each of its parameters.
    ///
    /// The optional parameters accept either an `Option<String>` or a bare `String`; the `for`
    /// chain accepts anything convertible into a `Vec<String>`.
    pub fn new(
        by: impl Into<Option<String>>,
        r#for: impl Into<Vec<String>>,
        host: impl Into<Option<String>>,
        proto: impl Into<Option<String>>,
    ) -> Self {
        let by = by.into();
        let chain = r#for.into();
        let host = host.into();
        let proto = proto.into();

        Self {
            by,
            r#for: chain,
            host,
            proto,
        }
    }

    /// Builds a header whose only parameter is a single `for` identifier.
    pub fn new_for(r#for: impl Into<String>) -> Self {
        let ident = r#for.into();

        Self {
            by: None,
            r#for: vec![ident],
            host: None,
            proto: None,
        }
    }

    /// Returns the first `for` identifier, if any.
    ///
    /// This is the value taken straight from the header and is only as trustworthy as the
    /// proxies that produced it.
    pub fn for_client(&self) -> Option<&str> {
        self.for_chain().next()
    }

    /// Iterates over every `for` identifier, first to last.
    pub fn for_chain(&self) -> impl Iterator<Item = &'_ str> {
        self.r#for.iter().map(String::as_str)
    }

    /// Returns the `by` parameter, if present.
    pub fn by(&self) -> Option<&str> {
        self.by.as_deref()
    }

    /// Returns the `host` parameter, if present.
    pub fn host(&self) -> Option<&str> {
        self.host.as_deref()
    }

    /// Returns the `proto` parameter, if present.
    pub fn proto(&self) -> Option<&str> {
        self.proto.as_deref()
    }

    /// Appends an identifier to the end of the `for` chain.
    pub fn push_for(&mut self, identifier: impl Into<String>) {
        let ident = identifier.into();
        self.r#for.push(ident);
    }

    /// Reports whether no parameter at all is set (empty `for` chain and every option `None`).
    fn has_no_info(&self) -> bool {
        matches!(
            self,
            Self {
                by: None,
                r#for,
                host: None,
                proto: None,
            } if r#for.is_empty()
        )
    }
}
132
133impl str::FromStr for Forwarded {
134 type Err = Infallible;
135
136 #[inline]
137 fn from_str(val: &str) -> Result<Self, Self::Err> {
138 let mut by = None;
139 let mut host = None;
140 let mut proto = None;
141 let mut r#for = vec![];
142
143 for (name, val) in val
145 .split(';')
146 .flat_map(|vals| vals.split(','))
148 .flat_map(|pair| {
150 let mut items = pair.trim().splitn(2, '=');
151 Some((items.next()?, items.next()?))
152 })
153 {
154 match name.trim().to_lowercase().as_str() {
158 "by" => {
159 by.get_or_insert_with(|| unquote(val));
161 }
162 "for" => {
163 r#for.push(unquote(val));
165 }
166 "host" => {
167 host.get_or_insert_with(|| unquote(val));
169 }
170 "proto" => {
171 proto.get_or_insert_with(|| unquote(val));
173 }
174 _ => continue,
175 };
176 }
177
178 Ok(Self {
179 by: by.map(str::to_owned),
180 r#for: r#for.into_iter().map(str::to_owned).collect(),
181 host: host.map(str::to_owned),
182 proto: proto.map(str::to_owned),
183 })
184 }
185}
186
187impl TryIntoHeaderValue for Forwarded {
188 type Error = header::InvalidHeaderValue;
189
190 fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
191 if self.has_no_info() {
192 return Ok(HeaderValue::from_static(""));
193 }
194
195 let r#for = if self.r#for.is_empty() {
196 None
197 } else {
198 let value = self
199 .r#for
200 .into_iter()
201 .map(|ident| format!("for=\"{ident}\""))
202 .join(", ");
203
204 Some(value)
205 };
206
207 self.by
211 .map(|by| format!("by=\"{by}\""))
212 .into_iter()
213 .chain(r#for)
214 .chain(self.host.map(|host| format!("host=\"{host}\"")))
215 .chain(self.proto.map(|proto| format!("proto=\"{proto}\"")))
216 .join("; ")
217 .try_into_value()
218 }
219}
220
221impl Header for Forwarded {
222 fn name() -> HeaderName {
223 header::FORWARDED
224 }
225
226 fn parse<M: HttpMessage>(msg: &M) -> Result<Self, ParseError> {
227 let combined = msg
228 .headers()
229 .get_all(Self::name())
230 .filter_map(|hdr| hdr.to_str().ok())
231 .filter_map(|hdr_str| match hdr_str.trim() {
232 "" => None,
233 val => Some(val),
234 })
235 .collect::<Vec<_>>();
236
237 if combined.is_empty() {
238 return Err(ParseError::Header);
239 }
240
241 combined.join(";").parse().map_err(|_| ParseError::Header)
244 }
245}
246
/// Trims surrounding whitespace, then strips every leading and trailing `"` character.
///
/// Quotes are removed greedily on each side independently, so repeated or unbalanced quotes
/// are stripped as well. Backslash escapes inside the value are left untouched.
fn unquote(val: &str) -> &str {
    let trimmed = val.trim();
    trimmed.trim_matches('"')
}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::header::{assert_parse_eq, assert_parse_fail};

    /// Builds an expected `Forwarded` value from borrowed parts.
    fn forwarded(
        by: Option<&str>,
        chain: &[&str],
        host: Option<&str>,
        proto: Option<&str>,
    ) -> Forwarded {
        Forwarded::new(
            by.map(String::from),
            chain.iter().copied().map(String::from).collect::<Vec<_>>(),
            host.map(String::from),
            proto.map(String::from),
        )
    }

    #[test]
    fn missing_header() {
        // no header lines at all, then a single blank line
        assert_parse_fail::<Forwarded, _, _>([""; 0]);
        assert_parse_fail::<Forwarded, _, _>([""]);
    }

    #[test]
    fn parsing_header_parts() {
        assert_parse_eq::<Forwarded, _, _>([";"], Forwarded::default());

        let all_params = forwarded(
            Some("203.0.113.43"),
            &["192.0.2.60"],
            Some("rust-lang.org"),
            Some("https"),
        );
        assert_parse_eq::<Forwarded, _, _>(
            ["for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org"],
            all_params,
        );

        // parameters spread across two header lines are combined
        let split_lines = forwarded(None, &["192.0.2.60"], Some("rust-lang.org"), Some("https"));
        assert_parse_eq::<Forwarded, _, _>(
            ["for=192.0.2.60; proto=https", "host=rust-lang.org"],
            split_lines,
        );
    }

    #[test]
    fn serializing() {
        let value = forwarded(
            Some("203.0.113.43"),
            &["192.0.2.60"],
            Some("rust-lang.org"),
            Some("https"),
        )
        .try_into_value()
        .unwrap();

        assert_eq!(
            value,
            r#"by="203.0.113.43"; for="192.0.2.60"; host="rust-lang.org"; proto="https""#
        );
    }

    #[test]
    fn case_sensitivity() {
        assert_parse_eq::<Forwarded, _, _>(["For=192.0.2.60"], Forwarded::new_for("192.0.2.60"));
    }

    #[test]
    fn weird_whitespace() {
        assert_parse_eq::<Forwarded, _, _>(
            ["for= 1.2.3.4; proto= https"],
            forwarded(None, &["1.2.3.4"], None, Some("https")),
        );

        assert_parse_eq::<Forwarded, _, _>([" for = 1.2.3.4 "], Forwarded::new_for("1.2.3.4"));
    }

    #[test]
    fn for_quoted() {
        assert_parse_eq::<Forwarded, _, _>(
            [r#"for="192.0.2.60:8080""#],
            Forwarded::new_for("192.0.2.60:8080"),
        );
    }

    #[test]
    fn for_ipv6() {
        assert_parse_eq::<Forwarded, _, _>(
            [r#"for="[2001:db8:cafe::17]:4711""#],
            Forwarded::new_for("[2001:db8:cafe::17]:4711"),
        );
    }

    #[test]
    fn for_multiple() {
        let mut chain = Forwarded::new_for("192.0.2.60");
        chain.push_for("198.51.100.17");

        assert_eq!(chain.for_client().unwrap(), "192.0.2.60");

        assert_parse_eq::<Forwarded, _, _>(["for=192.0.2.60, for=198.51.100.17"], chain);
    }
}