1use std::io;
7
8use derive_more::{Display, Error, From};
9use http::{header, Method, StatusCode};
10
11use crate::{body::BoxBody, header::HeaderValue, RequestHead, Response, ResponseBuilder};
12
13mod codec;
14mod dispatcher;
15mod frame;
16mod mask;
17mod proto;
18
19pub use self::{
20    codec::{Codec, Frame, Item, Message},
21    dispatcher::Dispatcher,
22    frame::Parser,
23    proto::{hash_key, CloseCode, CloseReason, OpCode},
24};
25
26#[derive(Debug, Display, Error, From)]
28pub enum ProtocolError {
29    #[display("received an unmasked frame from client")]
31    UnmaskedFrame,
32
33    #[display("received a masked frame from server")]
35    MaskedFrame,
36
37    #[display("invalid opcode ({})", _0)]
39    InvalidOpcode(#[error(not(source))] u8),
40
41    #[display("invalid control frame length ({})", _0)]
43    InvalidLength(#[error(not(source))] usize),
44
45    #[display("bad opcode")]
47    BadOpCode,
48
49    #[display("payload reached size limit")]
51    Overflow,
52
53    #[display("continuation has not started")]
55    ContinuationNotStarted,
56
57    #[display("received new continuation but it has already started")]
59    ContinuationStarted,
60
61    #[display("unknown continuation fragment: {}", _0)]
63    ContinuationFragment(#[error(not(source))] OpCode),
64
65    #[display("I/O error: {}", _0)]
67    Io(io::Error),
68}
69
/// Errors reported by [`verify_handshake`] when a request is not a valid
/// WebSocket opening handshake.
///
/// Each variant can be converted into an HTTP error `Response<BoxBody>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, Error)]
pub enum HandshakeError {
    /// The request method was not `GET`.
    #[display("method not allowed")]
    GetMethodRequired,

    /// The `Upgrade` header is missing, not valid visible ASCII, or does not
    /// mention `websocket`.
    #[display("WebSocket upgrade is expected")]
    NoWebsocketUpgrade,

    /// The request does not signal a connection upgrade (as judged by
    /// `RequestHead::upgrade`).
    #[display("connection upgrade is expected")]
    NoConnectionUpgrade,

    /// The `Sec-WebSocket-Version` header is missing.
    #[display("WebSocket version header is required")]
    NoVersionHeader,

    /// The `Sec-WebSocket-Version` header is not one of `13`, `8` or `7`.
    #[display("unsupported WebSocket version")]
    UnsupportedVersion,

    /// The `Sec-WebSocket-Key` header is missing.
    #[display("unknown WebSocket key")]
    BadWebsocketKey,
}
97
98impl From<HandshakeError> for Response<BoxBody> {
99    fn from(err: HandshakeError) -> Self {
100        match err {
101            HandshakeError::GetMethodRequired => {
102                let mut res = Response::new(StatusCode::METHOD_NOT_ALLOWED);
103                #[allow(clippy::declare_interior_mutable_const)]
104                const HV_GET: HeaderValue = HeaderValue::from_static("GET");
105                res.headers_mut().insert(header::ALLOW, HV_GET);
106                res
107            }
108
109            HandshakeError::NoWebsocketUpgrade => {
110                let mut res = Response::bad_request();
111                res.head_mut().reason = Some("No WebSocket Upgrade header found");
112                res
113            }
114
115            HandshakeError::NoConnectionUpgrade => {
116                let mut res = Response::bad_request();
117                res.head_mut().reason = Some("No Connection upgrade");
118                res
119            }
120
121            HandshakeError::NoVersionHeader => {
122                let mut res = Response::bad_request();
123                res.head_mut().reason = Some("WebSocket version header is required");
124                res
125            }
126
127            HandshakeError::UnsupportedVersion => {
128                let mut res = Response::bad_request();
129                res.head_mut().reason = Some("Unsupported WebSocket version");
130                res
131            }
132
133            HandshakeError::BadWebsocketKey => {
134                let mut res = Response::bad_request();
135                res.head_mut().reason = Some("Handshake error");
136                res
137            }
138        }
139    }
140}
141
142impl From<&HandshakeError> for Response<BoxBody> {
143    fn from(err: &HandshakeError) -> Self {
144        (*err).into()
145    }
146}
147
148pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
150    verify_handshake(req)?;
151    Ok(handshake_response(req))
152}
153
154pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
156    if req.method != Method::GET {
158        return Err(HandshakeError::GetMethodRequired);
159    }
160
161    let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) {
163        if let Ok(s) = hdr.to_str() {
164            s.to_ascii_lowercase().contains("websocket")
165        } else {
166            false
167        }
168    } else {
169        false
170    };
171    if !has_hdr {
172        return Err(HandshakeError::NoWebsocketUpgrade);
173    }
174
175    if !req.upgrade() {
177        return Err(HandshakeError::NoConnectionUpgrade);
178    }
179
180    if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
182        return Err(HandshakeError::NoVersionHeader);
183    }
184    let supported_ver = {
185        if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
186            hdr == "13" || hdr == "8" || hdr == "7"
187        } else {
188            false
189        }
190    };
191    if !supported_ver {
192        return Err(HandshakeError::UnsupportedVersion);
193    }
194
195    if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
197        return Err(HandshakeError::BadWebsocketKey);
198    }
199    Ok(())
200}
201
202pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
206    let key = {
207        let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
208        proto::hash_key(key.as_ref())
209    };
210
211    Response::build(StatusCode::SWITCHING_PROTOCOLS)
212        .upgrade("websocket")
213        .insert_header((
214            header::SEC_WEBSOCKET_ACCEPT,
215            HeaderValue::from_bytes(&key).unwrap(),
217        ))
218        .take()
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{header, test::TestRequest, Request};

    /// Builds a default (GET) test request carrying `headers`, inserted in order.
    fn request_with(headers: &[(header::HeaderName, &'static str)]) -> Request {
        let mut builder = TestRequest::default();
        for (name, value) in headers {
            builder.insert_header((name.clone(), header::HeaderValue::from_static(*value)));
        }
        builder.finish()
    }

    #[test]
    fn test_handshake() {
        let post = TestRequest::default().method(Method::POST).finish();
        assert_eq!(
            HandshakeError::GetMethodRequired,
            verify_handshake(post.head()).unwrap_err(),
        );

        // Each request below satisfies one more check than the previous one,
        // so verification fails one step further along.
        let bare = request_with(&[]);
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(bare.head()).unwrap_err(),
        );

        let wrong_upgrade = request_with(&[(header::UPGRADE, "test")]);
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(wrong_upgrade.head()).unwrap_err(),
        );

        let upgrade_only = request_with(&[(header::UPGRADE, "websocket")]);
        assert_eq!(
            HandshakeError::NoConnectionUpgrade,
            verify_handshake(upgrade_only.head()).unwrap_err(),
        );

        let no_version = request_with(&[
            (header::UPGRADE, "websocket"),
            (header::CONNECTION, "upgrade"),
        ]);
        assert_eq!(
            HandshakeError::NoVersionHeader,
            verify_handshake(no_version.head()).unwrap_err(),
        );

        let old_version = request_with(&[
            (header::UPGRADE, "websocket"),
            (header::CONNECTION, "upgrade"),
            (header::SEC_WEBSOCKET_VERSION, "5"),
        ]);
        assert_eq!(
            HandshakeError::UnsupportedVersion,
            verify_handshake(old_version.head()).unwrap_err(),
        );

        let no_key = request_with(&[
            (header::UPGRADE, "websocket"),
            (header::CONNECTION, "upgrade"),
            (header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert_eq!(
            HandshakeError::BadWebsocketKey,
            verify_handshake(no_key.head()).unwrap_err(),
        );

        let complete = request_with(&[
            (header::UPGRADE, "websocket"),
            (header::CONNECTION, "upgrade"),
            (header::SEC_WEBSOCKET_VERSION, "13"),
            (header::SEC_WEBSOCKET_KEY, "13"),
        ]);
        assert_eq!(
            StatusCode::SWITCHING_PROTOCOLS,
            handshake_response(complete.head()).finish().status()
        );
    }

    #[test]
    fn test_ws_error_http_response() {
        let expected = [
            (HandshakeError::GetMethodRequired, StatusCode::METHOD_NOT_ALLOWED),
            (HandshakeError::NoWebsocketUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoConnectionUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoVersionHeader, StatusCode::BAD_REQUEST),
            (HandshakeError::UnsupportedVersion, StatusCode::BAD_REQUEST),
            (HandshakeError::BadWebsocketKey, StatusCode::BAD_REQUEST),
        ];

        for (err, status) in expected {
            let resp: Response<BoxBody> = err.into();
            assert_eq!(resp.status(), status, "{err:?}");
        }
    }
}