1use std::cmp::min;
2
3use bytes::{Buf, BufMut, BytesMut};
4use tracing::debug;
5
6use super::{
7 mask::apply_mask,
8 proto::{CloseCode, CloseReason, OpCode},
9 ProtocolError,
10};
11
/// Field-less type that namespaces the WebSocket frame parsing and serialization functions.
///
/// It carries no state, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Parser;
16impl Parser {
17 fn parse_metadata(
18 src: &[u8],
19 server: bool,
20 ) -> Result<Option<(usize, bool, OpCode, usize, Option<[u8; 4]>)>, ProtocolError> {
21 let chunk_len = src.len();
22
23 let mut idx = 2;
24 if chunk_len < 2 {
25 return Ok(None);
26 }
27
28 let first = src[0];
29 let second = src[1];
30 let finished = first & 0x80 != 0;
31
32 let masked = second & 0x80 != 0;
34 if !masked && server {
35 return Err(ProtocolError::UnmaskedFrame);
36 } else if masked && !server {
37 return Err(ProtocolError::MaskedFrame);
38 }
39
40 let opcode = OpCode::from(first & 0x0F);
42
43 if let OpCode::Bad = opcode {
44 return Err(ProtocolError::InvalidOpcode(first & 0x0F));
45 }
46
47 let len = second & 0x7F;
48 let length = if len == 126 {
49 if chunk_len < 4 {
50 return Ok(None);
51 }
52 let len = usize::from(u16::from_be_bytes(
53 TryFrom::try_from(&src[idx..idx + 2]).unwrap(),
54 ));
55 idx += 2;
56 len
57 } else if len == 127 {
58 if chunk_len < 10 {
59 return Ok(None);
60 }
61 let len = u64::from_be_bytes(TryFrom::try_from(&src[idx..idx + 8]).unwrap());
62 idx += 8;
63 len as usize
64 } else {
65 len as usize
66 };
67
68 let mask = if server {
69 if chunk_len < idx + 4 {
70 return Ok(None);
71 }
72
73 let mask = TryFrom::try_from(&src[idx..idx + 4]).unwrap();
74
75 idx += 4;
76
77 Some(mask)
78 } else {
79 None
80 };
81
82 Ok(Some((idx, finished, opcode, length, mask)))
83 }
84
85 pub fn parse(
87 src: &mut BytesMut,
88 server: bool,
89 max_size: usize,
90 ) -> Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError> {
91 let (idx, finished, opcode, length, mask) = match Parser::parse_metadata(src, server)? {
93 None => return Ok(None),
94 Some(res) => res,
95 };
96
97 let frame_len = match idx.checked_add(length) {
98 Some(len) => len,
99 None => return Err(ProtocolError::Overflow),
100 };
101
102 if src.len() < frame_len {
104 let min_length = min(length, max_size);
105 let required_cap = match idx.checked_add(min_length) {
106 Some(cap) => cap,
107 None => return Err(ProtocolError::Overflow),
108 };
109
110 if src.capacity() < required_cap {
111 src.reserve(required_cap - src.capacity());
112 }
113 return Ok(None);
114 }
115
116 src.advance(idx);
118
119 if length > max_size {
121 src.advance(length);
123 return Err(ProtocolError::Overflow);
124 }
125
126 if length == 0 {
128 return Ok(Some((finished, opcode, None)));
129 }
130
131 let mut data = src.split_to(length);
132
133 match opcode {
135 OpCode::Ping | OpCode::Pong if length > 125 => {
136 return Err(ProtocolError::InvalidLength(length));
137 }
138 OpCode::Close if length > 125 => {
139 debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
140 return Ok(Some((true, OpCode::Close, None)));
141 }
142 _ => {}
143 }
144
145 if let Some(mask) = mask {
147 apply_mask(&mut data, mask);
148 }
149
150 Ok(Some((finished, opcode, Some(data))))
151 }
152
153 pub fn parse_close_payload(payload: &[u8]) -> Option<CloseReason> {
155 if payload.len() >= 2 {
156 let raw_code = u16::from_be_bytes(TryFrom::try_from(&payload[..2]).unwrap());
157 let code = CloseCode::from(raw_code);
158 let description = if payload.len() > 2 {
159 Some(String::from_utf8_lossy(&payload[2..]).into())
160 } else {
161 None
162 };
163 Some(CloseReason { code, description })
164 } else {
165 None
166 }
167 }
168
169 pub fn write_message<B: AsRef<[u8]>>(
171 dst: &mut BytesMut,
172 pl: B,
173 op: OpCode,
174 fin: bool,
175 mask: bool,
176 ) {
177 let payload = pl.as_ref();
178 let one: u8 = if fin {
179 0x80 | Into::<u8>::into(op)
180 } else {
181 op.into()
182 };
183 let payload_len = payload.len();
184 let (two, p_len) = if mask {
185 (0x80, payload_len + 4)
186 } else {
187 (0, payload_len)
188 };
189
190 if payload_len < 126 {
191 dst.reserve(p_len + 2);
192 dst.put_slice(&[one, two | payload_len as u8]);
193 } else if payload_len <= 65_535 {
194 dst.reserve(p_len + 4);
195 dst.put_slice(&[one, two | 126]);
196 dst.put_u16(payload_len as u16);
197 } else {
198 dst.reserve(p_len + 10);
199 dst.put_slice(&[one, two | 127]);
200 dst.put_u64(payload_len as u64);
201 };
202
203 if mask {
204 let mask = rand::random::<[u8; 4]>();
205 dst.put_slice(mask.as_ref());
206 dst.put_slice(payload.as_ref());
207 let pos = dst.len() - payload_len;
208 apply_mask(&mut dst[pos..], mask);
209 } else {
210 dst.put_slice(payload.as_ref());
211 }
212 }
213
214 #[inline]
216 pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
217 let payload = match reason {
218 None => Vec::new(),
219 Some(reason) => {
220 let mut payload = Into::<u16>::into(reason.code).to_be_bytes().to_vec();
221 if let Some(description) = reason.description {
222 payload.extend(description.as_bytes());
223 }
224 payload
225 }
226 };
227
228 Parser::write_message(dst, payload, OpCode::Close, true, mask)
229 }
230}
231
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// A decoded frame; `payload` is empty when the parser returned no payload.
    struct F {
        finished: bool,
        opcode: OpCode,
        payload: Bytes,
    }

    type ParseResult = Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>;

    /// Returns `true` when the parser reported an incomplete frame.
    fn is_none(frm: &ParseResult) -> bool {
        matches!(*frm, Ok(None))
    }

    /// Unwraps a successfully parsed frame, panicking with the reason otherwise.
    fn extract(frm: ParseResult) -> F {
        match frm {
            Ok(Some((finished, opcode, payload))) => F {
                finished,
                opcode,
                payload: payload.map(BytesMut::freeze).unwrap_or_default(),
            },
            Ok(None) => panic!("expected a complete frame, but the parser needs more data"),
            Err(_) => panic!("expected a complete frame, but the parser returned an error"),
        }
    }

    #[test]
    fn test_parse() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        buf.extend(b"1");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1"[..]);
    }

    #[test]
    fn test_parse_length0() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]);
        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_parse_length2() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
        buf.extend(&[0u8, 4u8][..]);
        buf.extend(b"1234");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_length4() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]);
        buf.extend(b"1234");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_frame_mask() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]);
        buf.extend(b"0001");
        buf.extend(b"1");

        assert!(Parser::parse(&mut buf, false, 1024).is_err());

        let frame = extract(Parser::parse(&mut buf, true, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_no_mask() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        buf.extend([1u8]);

        assert!(Parser::parse(&mut buf, true, 1024).is_err());

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_max_size() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]);
        buf.extend([1u8, 1u8]);

        assert!(Parser::parse(&mut buf, true, 1).is_err());

        assert!(matches!(
            Parser::parse(&mut buf, false, 0),
            Err(ProtocolError::Overflow)
        ));
    }

    #[test]
    fn test_parse_frame_max_size_recoverability() {
        let mut buf = BytesMut::new();
        // An oversized text frame followed by a valid binary frame.
        buf.extend([0b0000_0001u8, 0b0000_0010u8, 0b0000_0000u8, 0b0000_0000u8]);
        buf.extend([0b0000_0010u8, 0b0000_0010u8, 0b1111_1111u8, 0b1111_1111u8]);

        assert_eq!(buf.len(), 8);
        assert!(matches!(
            Parser::parse(&mut buf, false, 1),
            Err(ProtocolError::Overflow)
        ));
        assert_eq!(buf.len(), 4);
        let frame = extract(Parser::parse(&mut buf, false, 2));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Binary);
        assert_eq!(
            frame.payload,
            Bytes::from(vec![0b1111_1111u8, 0b1111_1111u8])
        );
        assert_eq!(buf.len(), 0);
    }

    #[test]
    fn test_parse_incomplete_mask_key() {
        // A masked header whose masking key has not fully arrived is incomplete, not an error.
        let mut buf = BytesMut::from(&[0b1000_0001u8, 0b1000_0001u8, 0u8, 0u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, true, 1024)));
        assert_eq!(buf.len(), 4);
    }

    #[test]
    fn test_parse_invalid_opcode() {
        // Opcode 0x3 is reserved by RFC 6455.
        let mut buf = BytesMut::from(&[0b1000_0011u8, 0u8][..]);
        assert!(matches!(
            Parser::parse(&mut buf, false, 1024),
            Err(ProtocolError::InvalidOpcode(3))
        ));
    }

    #[test]
    fn test_parse_ping_too_long() {
        let mut buf = BytesMut::from(&[0b1000_1001u8, 126u8, 0u8, 126u8][..]);
        buf.extend_from_slice(&[0u8; 126]);
        assert!(matches!(
            Parser::parse(&mut buf, false, 1024),
            Err(ProtocolError::InvalidLength(126))
        ));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_parse_close_too_long() {
        let mut buf = BytesMut::from(&[0b1000_1000u8, 126u8, 0u8, 126u8][..]);
        buf.extend_from_slice(&[b'x'; 126]);
        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(frame.finished);
        assert_eq!(frame.opcode, OpCode::Close);
        assert!(frame.payload.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_write_parse_roundtrip() {
        // Exercises the 7-bit, 16-bit and 64-bit length encodings, masked and unmasked.
        for &len in &[0usize, 125, 126, 65_535, 65_536] {
            let payload = vec![0xA5u8; len];
            for &masked in &[false, true] {
                let mut buf = BytesMut::new();
                Parser::write_message(&mut buf, &payload, OpCode::Binary, true, masked);
                let frame = extract(Parser::parse(&mut buf, masked, len));
                assert!(frame.finished);
                assert_eq!(frame.opcode, OpCode::Binary);
                assert_eq!(frame.payload.as_ref(), &payload[..]);
                assert!(buf.is_empty());
            }
        }
    }

    #[test]
    fn test_write_message_not_finished() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, "ab", OpCode::Text, false, false);
        assert_eq!(&buf[..], &[0x01, 0x02, b'a', b'b'][..]);
    }

    #[test]
    fn test_ping_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);

        let mut v = vec![137u8, 4u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_pong_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);

        let mut v = vec![138u8, 4u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_close_frame() {
        let mut buf = BytesMut::new();
        let reason = (CloseCode::Normal, "data");
        Parser::write_close(&mut buf, Some(reason.into()), false);

        let mut v = vec![136u8, 6u8, 3u8, 232u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_empty_close_frame() {
        let mut buf = BytesMut::new();
        Parser::write_close(&mut buf, None, false);
        assert_eq!(&buf[..], &vec![0x88, 0x00][..]);
    }

    #[test]
    fn test_close_roundtrip() {
        let mut buf = BytesMut::new();
        Parser::write_close(&mut buf, Some((CloseCode::Normal, "bye").into()), true);
        let frame = extract(Parser::parse(&mut buf, true, 1024));
        assert!(frame.finished);
        assert_eq!(frame.opcode, OpCode::Close);
        let reason = Parser::parse_close_payload(&frame.payload).expect("close reason");
        assert_eq!(Into::<u16>::into(reason.code), 1000);
        assert_eq!(reason.description.as_deref(), Some("bye"));
    }

    #[test]
    fn test_parse_close_payload() {
        assert!(Parser::parse_close_payload(&[]).is_none());
        assert!(Parser::parse_close_payload(&[0x03]).is_none());

        let reason = Parser::parse_close_payload(&[0x03, 0xE8]).expect("code-only payload");
        assert_eq!(Into::<u16>::into(reason.code), 1000);
        assert!(reason.description.is_none());

        let reason =
            Parser::parse_close_payload(b"\x03\xE8bye").expect("payload with description");
        assert_eq!(Into::<u16>::into(reason.code), 1000);
        assert_eq!(reason.description.as_deref(), Some("bye"));
    }

    #[test]
    fn test_parse_length_overflow() {
        let buf: [u8; 14] = [
            0x0a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xeb, 0x0e, 0x8f,
        ];
        let mut buf = BytesMut::from(&buf[..]);
        let result = Parser::parse(&mut buf, true, 65536);
        assert!(matches!(result, Err(ProtocolError::Overflow)));
    }
}