actix_http/ws/
frame.rs

1use std::cmp::min;
2
3use bytes::{Buf, BufMut, BytesMut};
4use tracing::debug;
5
6use super::{
7    mask::apply_mask,
8    proto::{CloseCode, CloseReason, OpCode},
9    ProtocolError,
10};
11
/// Stateless WebSocket frame parser and writer.
///
/// `Parser` holds no data; it only groups the associated functions that decode frames from a
/// byte buffer and encode frames into one.
#[derive(Debug)]
pub struct Parser;
15
16impl Parser {
17    fn parse_metadata(
18        src: &[u8],
19        server: bool,
20    ) -> Result<Option<(usize, bool, OpCode, usize, Option<[u8; 4]>)>, ProtocolError> {
21        let chunk_len = src.len();
22
23        let mut idx = 2;
24        if chunk_len < 2 {
25            return Ok(None);
26        }
27
28        let first = src[0];
29        let second = src[1];
30        let finished = first & 0x80 != 0;
31
32        // check masking
33        let masked = second & 0x80 != 0;
34        if !masked && server {
35            return Err(ProtocolError::UnmaskedFrame);
36        } else if masked && !server {
37            return Err(ProtocolError::MaskedFrame);
38        }
39
40        // Op code
41        let opcode = OpCode::from(first & 0x0F);
42
43        if let OpCode::Bad = opcode {
44            return Err(ProtocolError::InvalidOpcode(first & 0x0F));
45        }
46
47        let len = second & 0x7F;
48        let length = if len == 126 {
49            if chunk_len < 4 {
50                return Ok(None);
51            }
52            let len = usize::from(u16::from_be_bytes(
53                TryFrom::try_from(&src[idx..idx + 2]).unwrap(),
54            ));
55            idx += 2;
56            len
57        } else if len == 127 {
58            if chunk_len < 10 {
59                return Ok(None);
60            }
61            let len = u64::from_be_bytes(TryFrom::try_from(&src[idx..idx + 8]).unwrap());
62            idx += 8;
63            len as usize
64        } else {
65            len as usize
66        };
67
68        let mask = if server {
69            if chunk_len < idx + 4 {
70                return Ok(None);
71            }
72
73            let mask = TryFrom::try_from(&src[idx..idx + 4]).unwrap();
74
75            idx += 4;
76
77            Some(mask)
78        } else {
79            None
80        };
81
82        Ok(Some((idx, finished, opcode, length, mask)))
83    }
84
85    /// Parse the input stream into a frame.
86    pub fn parse(
87        src: &mut BytesMut,
88        server: bool,
89        max_size: usize,
90    ) -> Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError> {
91        // try to parse ws frame metadata
92        let (idx, finished, opcode, length, mask) = match Parser::parse_metadata(src, server)? {
93            None => return Ok(None),
94            Some(res) => res,
95        };
96
97        let frame_len = match idx.checked_add(length) {
98            Some(len) => len,
99            None => return Err(ProtocolError::Overflow),
100        };
101
102        // not enough data
103        if src.len() < frame_len {
104            let min_length = min(length, max_size);
105            let required_cap = match idx.checked_add(min_length) {
106                Some(cap) => cap,
107                None => return Err(ProtocolError::Overflow),
108            };
109
110            if src.capacity() < required_cap {
111                src.reserve(required_cap - src.capacity());
112            }
113            return Ok(None);
114        }
115
116        // remove prefix
117        src.advance(idx);
118
119        // check for max allowed size
120        if length > max_size {
121            // drop the payload
122            src.advance(length);
123            return Err(ProtocolError::Overflow);
124        }
125
126        // no need for body
127        if length == 0 {
128            return Ok(Some((finished, opcode, None)));
129        }
130
131        let mut data = src.split_to(length);
132
133        // control frames must have length <= 125
134        match opcode {
135            OpCode::Ping | OpCode::Pong if length > 125 => {
136                return Err(ProtocolError::InvalidLength(length));
137            }
138            OpCode::Close if length > 125 => {
139                debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
140                return Ok(Some((true, OpCode::Close, None)));
141            }
142            _ => {}
143        }
144
145        // unmask
146        if let Some(mask) = mask {
147            apply_mask(&mut data, mask);
148        }
149
150        Ok(Some((finished, opcode, Some(data))))
151    }
152
153    /// Parse the payload of a close frame.
154    pub fn parse_close_payload(payload: &[u8]) -> Option<CloseReason> {
155        if payload.len() >= 2 {
156            let raw_code = u16::from_be_bytes(TryFrom::try_from(&payload[..2]).unwrap());
157            let code = CloseCode::from(raw_code);
158            let description = if payload.len() > 2 {
159                Some(String::from_utf8_lossy(&payload[2..]).into())
160            } else {
161                None
162            };
163            Some(CloseReason { code, description })
164        } else {
165            None
166        }
167    }
168
169    /// Generate binary representation
170    pub fn write_message<B: AsRef<[u8]>>(
171        dst: &mut BytesMut,
172        pl: B,
173        op: OpCode,
174        fin: bool,
175        mask: bool,
176    ) {
177        let payload = pl.as_ref();
178        let one: u8 = if fin {
179            0x80 | Into::<u8>::into(op)
180        } else {
181            op.into()
182        };
183        let payload_len = payload.len();
184        let (two, p_len) = if mask {
185            (0x80, payload_len + 4)
186        } else {
187            (0, payload_len)
188        };
189
190        if payload_len < 126 {
191            dst.reserve(p_len + 2);
192            dst.put_slice(&[one, two | payload_len as u8]);
193        } else if payload_len <= 65_535 {
194            dst.reserve(p_len + 4);
195            dst.put_slice(&[one, two | 126]);
196            dst.put_u16(payload_len as u16);
197        } else {
198            dst.reserve(p_len + 10);
199            dst.put_slice(&[one, two | 127]);
200            dst.put_u64(payload_len as u64);
201        };
202
203        if mask {
204            let mask = rand::random::<[u8; 4]>();
205            dst.put_slice(mask.as_ref());
206            dst.put_slice(payload.as_ref());
207            let pos = dst.len() - payload_len;
208            apply_mask(&mut dst[pos..], mask);
209        } else {
210            dst.put_slice(payload.as_ref());
211        }
212    }
213
214    /// Create a new Close control frame.
215    #[inline]
216    pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
217        let payload = match reason {
218            None => Vec::new(),
219            Some(reason) => {
220                let mut payload = Into::<u16>::into(reason.code).to_be_bytes().to_vec();
221                if let Some(description) = reason.description {
222                    payload.extend(description.as_bytes());
223                }
224                payload
225            }
226        };
227
228        Parser::write_message(dst, payload, OpCode::Close, true, mask)
229    }
230}
231
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// Shape of the value returned by `Parser::parse`.
    type ParseResult = Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>;

    /// A decoded frame flattened into plain fields for assertions.
    struct F {
        finished: bool,
        opcode: OpCode,
        payload: Bytes,
    }

    /// True when the parser reported that more input is needed.
    fn is_none(frm: &ParseResult) -> bool {
        frm.as_ref().map_or(false, Option::is_none)
    }

    /// Unwraps a successfully parsed frame; a missing payload becomes empty bytes.
    fn extract(frm: ParseResult) -> F {
        let (finished, opcode, payload) = match frm {
            Ok(Some(parts)) => parts,
            _ => unreachable!("error"),
        };

        F {
            finished,
            opcode,
            payload: payload.map_or_else(|| Bytes::from(""), BytesMut::freeze),
        }
    }

    #[test]
    fn test_parse() {
        // Text frame header announcing a one-byte payload.
        let header = [0b0000_0001u8, 0b0000_0001u8];

        let mut incomplete = BytesMut::from(&header[..]);
        assert!(is_none(&Parser::parse(&mut incomplete, false, 1024)));

        let mut complete = BytesMut::from(&header[..]);
        complete.extend_from_slice(b"1");

        let F {
            finished,
            opcode,
            payload,
        } = extract(Parser::parse(&mut complete, false, 1024));
        assert!(!finished);
        assert_eq!(opcode, OpCode::Text);
        assert_eq!(payload.as_ref(), &b"1"[..]);
    }

    #[test]
    fn test_parse_length0() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]);

        let F {
            finished,
            opcode,
            payload,
        } = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!finished);
        assert_eq!(opcode, OpCode::Text);
        assert!(payload.is_empty());
    }

    #[test]
    fn test_parse_length2() {
        // Length marker 126: a 16-bit extended length follows.
        let header = [0b0000_0001u8, 126u8];

        let mut incomplete = BytesMut::from(&header[..]);
        assert!(is_none(&Parser::parse(&mut incomplete, false, 1024)));

        let mut complete = BytesMut::from(&header[..]);
        complete.extend_from_slice(&[0u8, 4u8]);
        complete.extend_from_slice(b"1234");

        let F {
            finished,
            opcode,
            payload,
        } = extract(Parser::parse(&mut complete, false, 1024));
        assert!(!finished);
        assert_eq!(opcode, OpCode::Text);
        assert_eq!(payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_length4() {
        // Length marker 127: a 64-bit extended length follows.
        let header = [0b0000_0001u8, 127u8];

        let mut incomplete = BytesMut::from(&header[..]);
        assert!(is_none(&Parser::parse(&mut incomplete, false, 1024)));

        let mut complete = BytesMut::from(&header[..]);
        complete.extend_from_slice(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8]);
        complete.extend_from_slice(b"1234");

        let F {
            finished,
            opcode,
            payload,
        } = extract(Parser::parse(&mut complete, false, 1024));
        assert!(!finished);
        assert_eq!(opcode, OpCode::Text);
        assert_eq!(payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_frame_mask() {
        // Masked text frame: key "0001", one payload byte "1".
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]);
        buf.extend_from_slice(b"0001");
        buf.extend_from_slice(b"1");

        // Parsing as a client rejects the masked frame.
        let as_client = Parser::parse(&mut buf, false, 1024);
        assert!(as_client.is_err());

        let F {
            finished,
            opcode,
            payload,
        } = extract(Parser::parse(&mut buf, true, 1024));
        assert!(!finished);
        assert_eq!(opcode, OpCode::Text);
        assert_eq!(payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_no_mask() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        buf.extend_from_slice(&[1u8]);

        // Parsing as a server rejects the unmasked frame.
        let as_server = Parser::parse(&mut buf, true, 1024);
        assert!(as_server.is_err());

        let F {
            finished,
            opcode,
            payload,
        } = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!finished);
        assert_eq!(opcode, OpCode::Text);
        assert_eq!(payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_max_size() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]);
        buf.extend_from_slice(&[1u8, 1u8]);

        let as_server = Parser::parse(&mut buf, true, 1);
        assert!(as_server.is_err());

        match Parser::parse(&mut buf, false, 0) {
            Err(ProtocolError::Overflow) => {}
            _ => unreachable!("error"),
        }
    }

    #[test]
    fn test_parse_frame_max_size_recoverability() {
        // Text frame with a two-byte payload; its content is irrelevant.
        let oversized = [0b0000_0001u8, 0b0000_0010u8, 0b0000_0000u8, 0b0000_0000u8];
        // Binary frame with the two-byte payload `[0b1111_1111, 0b1111_1111]`.
        let follow_up = [0b0000_0010u8, 0b0000_0010u8, 0b1111_1111u8, 0b1111_1111u8];

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&oversized);
        buf.extend_from_slice(&follow_up);
        assert_eq!(buf.len(), 8);

        // The first frame exceeds the limit and is dropped from the buffer.
        let first = Parser::parse(&mut buf, false, 1);
        assert!(matches!(first, Err(ProtocolError::Overflow)));
        assert_eq!(buf.len(), 4);

        // The second frame is still parsed from the remaining bytes.
        let F {
            finished,
            opcode,
            payload,
        } = extract(Parser::parse(&mut buf, false, 2));
        assert!(!finished);
        assert_eq!(opcode, OpCode::Binary);
        assert_eq!(payload, Bytes::from(vec![0b1111_1111u8, 0b1111_1111u8]));
        assert_eq!(buf.len(), 0);
    }

    #[test]
    fn test_ping_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);

        let expected = [&[137u8, 4u8][..], &b"data"[..]].concat();
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_pong_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);

        let expected = [&[138u8, 4u8][..], &b"data"[..]].concat();
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_close_frame() {
        let reason = (CloseCode::Normal, "data");

        let mut buf = BytesMut::new();
        Parser::write_close(&mut buf, Some(reason.into()), false);

        let expected = [&[136u8, 6u8, 3u8, 232u8][..], &b"data"[..]].concat();
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_empty_close_frame() {
        let mut buf = BytesMut::new();
        Parser::write_close(&mut buf, None, false);

        let expected = [0x88u8, 0x00u8];
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_parse_length_overflow() {
        // Masked frame announcing the maximum 64-bit payload length.
        let raw: [u8; 14] = [
            0x0a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xeb, 0x0e, 0x8f,
        ];
        let mut buf = BytesMut::from(&raw[..]);

        let outcome = Parser::parse(&mut buf, true, 65536);
        assert!(matches!(outcome, Err(ProtocolError::Overflow)));
    }
}