tide_disco/
socket.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the tide-disco library.
3
4// You should have received a copy of the MIT License
5// along with the tide-disco library. If not, see <https://mit-license.org/>.
6
7//! An interface for asynchronous communication with clients, using WebSockets.
8
9use crate::{
10    http::{content::Accept, mime},
11    method::Method,
12    request::{best_response_type, RequestError, RequestParams},
13    StatusCode,
14};
15use async_std::sync::Arc;
16use futures::{
17    future::BoxFuture,
18    sink,
19    stream::BoxStream,
20    task::{Context, Poll},
21    FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt,
22};
23use pin_project::pin_project;
24use serde::{de::DeserializeOwned, Serialize};
25use std::borrow::Cow;
26use std::fmt::{self, Display, Formatter};
27use std::marker::PhantomData;
28use std::pin::Pin;
29use tide_websockets::{
30    tungstenite::protocol::frame::{coding::CloseCode, CloseFrame},
31    Message, WebSocketConnection,
32};
33use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
34
35/// An error returned by a socket handler.
36///
37/// [SocketError] encapsulates application specific errors `E` returned by the user-installed
38/// handler itself. It also includes errors in the socket protocol, such as failures to turn
39/// messages sent by the user-installed handler into WebSockets messages.
40#[derive(Debug)]
41pub enum SocketError<E> {
42    AppSpecific(E),
43    Request(RequestError),
44    Binary(anyhow::Error),
45    Json(serde_json::Error),
46    WebSockets(tide_websockets::Error),
47    UnsupportedMessageType,
48    Closed,
49    IncorrectMethod { expected: Method, actual: Method },
50}
51
52impl<E> SocketError<E> {
53    pub fn status(&self) -> StatusCode {
54        match self {
55            Self::Request(_) | Self::UnsupportedMessageType | Self::IncorrectMethod { .. } => {
56                StatusCode::BAD_REQUEST
57            }
58            _ => StatusCode::INTERNAL_SERVER_ERROR,
59        }
60    }
61
62    pub fn code(&self) -> CloseCode {
63        CloseCode::Error
64    }
65
66    pub fn map_app_specific<E2>(self, f: &impl Fn(E) -> E2) -> SocketError<E2> {
67        match self {
68            Self::AppSpecific(e) => SocketError::AppSpecific(f(e)),
69            Self::Request(e) => SocketError::Request(e),
70            Self::Binary(e) => SocketError::Binary(e),
71            Self::Json(e) => SocketError::Json(e),
72            Self::WebSockets(e) => SocketError::WebSockets(e),
73            Self::UnsupportedMessageType => SocketError::UnsupportedMessageType,
74            Self::Closed => SocketError::Closed,
75            Self::IncorrectMethod { expected, actual } => {
76                SocketError::IncorrectMethod { expected, actual }
77            }
78        }
79    }
80}
81
82impl<E: Display> Display for SocketError<E> {
83    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
84        match self {
85            Self::AppSpecific(e) => write!(f, "{}", e),
86            Self::Request(e) => write!(f, "{}", e),
87            Self::Binary(e) => write!(f, "error creating byte stream: {}", e),
88            Self::Json(e) => write!(f, "error creating JSON message: {}", e),
89            Self::WebSockets(e) => write!(f, "WebSockets protocol error: {}", e),
90            Self::UnsupportedMessageType => {
91                write!(f, "unsupported content type for WebSockets message")
92            }
93            Self::Closed => write!(f, "connection closed"),
94            Self::IncorrectMethod { expected, actual } => write!(
95                f,
96                "endpoint must be called as {}, but was called as {}",
97                expected, actual
98            ),
99        }
100    }
101}
102
103impl<E> From<RequestError> for SocketError<E> {
104    fn from(err: RequestError) -> Self {
105        Self::Request(err)
106    }
107}
108
109impl<E> From<anyhow::Error> for SocketError<E> {
110    fn from(err: anyhow::Error) -> Self {
111        Self::Binary(err)
112    }
113}
114
115impl<E> From<serde_json::Error> for SocketError<E> {
116    fn from(err: serde_json::Error) -> Self {
117        Self::Json(err)
118    }
119}
120
121impl<E> From<tide_websockets::Error> for SocketError<E> {
122    fn from(err: tide_websockets::Error) -> Self {
123        Self::WebSockets(err)
124    }
125}
126
/// The wire encoding negotiated for messages sent to the client.
#[derive(Clone, Copy, Debug)]
enum MessageType {
    /// Binary frames produced by the versioned binary serializer.
    Binary,
    /// Text frames containing JSON.
    Json,
}
132
133/// A connection facilitating bi-directional, asynchronous communication with a client.
134///
135/// [Connection] implements [Stream], which can be used to receive `FromClient` messages from the
136/// client, and [Sink] which can be used to send `ToClient` messages to the client.
137#[pin_project]
138pub struct Connection<ToClient: ?Sized, FromClient, Error, VER: StaticVersionType> {
139    #[pin]
140    conn: WebSocketConnection,
141    // [Sink] wrapper around `conn`
142    sink: Pin<Box<dyn Send + Sink<Message, Error = SocketError<Error>>>>,
143    accept: MessageType,
144    #[allow(clippy::type_complexity)]
145    _phantom: PhantomData<fn(&ToClient, &FromClient, &Error, &VER) -> ()>,
146}
147
148impl<ToClient: ?Sized, FromClient: DeserializeOwned, E, VER: StaticVersionType> Stream
149    for Connection<ToClient, FromClient, E, VER>
150{
151    type Item = Result<FromClient, SocketError<E>>;
152
153    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154        // Get a `Pin<&mut WebSocketConnection>` for the underlying connection, so we can use the
155        // `Stream` implementation of that field.
156        match self.project().conn.poll_next(cx) {
157            Poll::Ready(None) => Poll::Ready(None),
158            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
159            Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(match msg {
160                Message::Binary(bytes) => {
161                    Serializer::<VER>::deserialize(&bytes).map_err(SocketError::from)
162                }
163                Message::Text(s) => serde_json::from_str(&s).map_err(SocketError::from),
164                _ => Err(SocketError::UnsupportedMessageType),
165            })),
166            Poll::Pending => Poll::Pending,
167        }
168    }
169}
170
171impl<ToClient: Serialize + ?Sized, FromClient, E, VER: StaticVersionType> Sink<&ToClient>
172    for Connection<ToClient, FromClient, E, VER>
173{
174    type Error = SocketError<E>;
175
176    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        self.sink.as_mut().poll_ready(cx).map_err(SocketError::from)
178    }
179
180    fn start_send(mut self: Pin<&mut Self>, item: &ToClient) -> Result<(), Self::Error> {
181        let msg = match self.accept {
182            MessageType::Binary => Message::Binary(Serializer::<VER>::serialize(item)?),
183            MessageType::Json => Message::Text(serde_json::to_string(item)?),
184        };
185        self.sink.as_mut().start_send(msg)
186    }
187
188    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189        self.sink.as_mut().poll_flush(cx).map_err(SocketError::from)
190    }
191
192    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193        self.sink.as_mut().poll_close(cx).map_err(SocketError::from)
194    }
195}
196
197impl<ToClient: Serialize, FromClient, E, VER: StaticVersionType> Sink<ToClient>
198    for Connection<ToClient, FromClient, E, VER>
199{
200    type Error = SocketError<E>;
201
202    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203        Sink::<&ToClient>::poll_ready(self, cx)
204    }
205
206    fn start_send(self: Pin<&mut Self>, item: ToClient) -> Result<(), Self::Error> {
207        self.start_send(&item)
208    }
209
210    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211        Sink::<&ToClient>::poll_flush(self, cx)
212    }
213
214    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215        Sink::<&ToClient>::poll_close(self, cx)
216    }
217}
218
219impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType>
220    Connection<ToClient, FromClient, E, VER>
221{
222    #[allow(clippy::result_large_err)]
223    fn new(accept: &Accept, conn: WebSocketConnection) -> Result<Self, SocketError<E>> {
224        let ty = best_response_type(accept, &[mime::JSON, mime::BYTE_STREAM])?;
225        let ty = if ty == mime::JSON {
226            MessageType::Json
227        } else if ty == mime::BYTE_STREAM {
228            MessageType::Binary
229        } else {
230            unreachable!()
231        };
232        Ok(Self {
233            sink: Self::sink(conn.clone()),
234            conn,
235            accept: ty,
236            _phantom: Default::default(),
237        })
238    }
239
240    /// Wrap a `WebSocketConnection` in a type that implements `Sink<Message>`.
241    fn sink(
242        conn: WebSocketConnection,
243    ) -> Pin<Box<dyn Send + Sink<Message, Error = SocketError<E>>>> {
244        Box::pin(sink::unfold(conn, |conn, msg| async move {
245            conn.send(msg).await?;
246            Ok(conn)
247        }))
248    }
249}
250
251impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType> Clone
252    for Connection<ToClient, FromClient, E, VER>
253{
254    fn clone(&self) -> Self {
255        Self {
256            sink: Self::sink(self.conn.clone()),
257            conn: self.conn.clone(),
258            accept: self.accept,
259            _phantom: Default::default(),
260        }
261    }
262}
263
264pub(crate) type Handler<State, Error> = Box<
265    dyn 'static
266        + Send
267        + Sync
268        + Fn(RequestParams, WebSocketConnection, &State) -> BoxFuture<Result<(), SocketError<Error>>>,
269>;
270
271pub(crate) fn handler<State, Error, ToClient, FromClient, F, VER: StaticVersionType>(
272    f: F,
273) -> Handler<State, Error>
274where
275    F: 'static
276        + Send
277        + Sync
278        + Fn(
279            RequestParams,
280            Connection<ToClient, FromClient, Error, VER>,
281            &State,
282        ) -> BoxFuture<Result<(), Error>>,
283    State: 'static + Send + Sync,
284    ToClient: 'static + Serialize + ?Sized,
285    FromClient: 'static + DeserializeOwned,
286    Error: 'static + Send + Display,
287{
288    raw_handler(move |req, conn, state| {
289        f(req, conn, state)
290            .map_err(SocketError::AppSpecific)
291            .boxed()
292    })
293}
294
295struct StreamHandler<F, VER: StaticVersionType>(F, PhantomData<VER>);
296
297impl<F, VER: StaticVersionType> StreamHandler<F, VER> {
298    fn handle<'a, State, Error, Msg>(
299        &self,
300        req: RequestParams,
301        mut conn: Connection<Msg, (), Error, VER>,
302        state: &'a State,
303    ) -> BoxFuture<'a, Result<(), SocketError<Error>>>
304    where
305        F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
306        State: 'static + Send + Sync,
307        Msg: 'static + Serialize + Send + Sync,
308        Error: 'static + Send,
309        VER: 'static + Send + Sync,
310    {
311        let mut stream = (self.0)(req, state);
312        async move {
313            while let Some(msg) = stream.next().await {
314                conn.send(&msg.map_err(SocketError::AppSpecific)?).await?;
315            }
316            Ok(())
317        }
318        .boxed()
319    }
320}
321
322pub(crate) fn stream_handler<State, Error, Msg, F, VER>(f: F) -> Handler<State, Error>
323where
324    F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
325    State: 'static + Send + Sync,
326    Msg: 'static + Serialize + Send + Sync,
327    Error: 'static + Send + Display,
328    VER: 'static + Send + Sync + StaticVersionType,
329{
330    let handler: StreamHandler<F, VER> = StreamHandler(f, Default::default());
331    raw_handler(move |req, conn, state| handler.handle(req, conn, state))
332}
333
334fn raw_handler<State, Error, ToClient, FromClient, F, VER>(f: F) -> Handler<State, Error>
335where
336    F: 'static
337        + Send
338        + Sync
339        + Fn(
340            RequestParams,
341            Connection<ToClient, FromClient, Error, VER>,
342            &State,
343        ) -> BoxFuture<Result<(), SocketError<Error>>>,
344    State: 'static + Send + Sync,
345    ToClient: 'static + Serialize + ?Sized,
346    FromClient: 'static + DeserializeOwned,
347    Error: 'static + Send + Display,
348    VER: StaticVersionType,
349{
350    let close = |conn: WebSocketConnection, res: Result<(), SocketError<Error>>| async move {
351        // When the handler finishes, send a close message. If there was an error, include the error
352        // message.
353        let msg = res.as_ref().err().map(|err| CloseFrame {
354            code: err.code(),
355            reason: Cow::Owned(err.to_string()),
356        });
357        conn.send(Message::Close(msg)).await?;
358        res
359    };
360    Box::new(move |req, raw_conn, state| {
361        let accept = match req.accept() {
362            Ok(accept) => accept,
363            Err(err) => return close(raw_conn, Err(err.into())).boxed(),
364        };
365        let conn = match Connection::new(&accept, raw_conn.clone()) {
366            Ok(conn) => conn,
367            Err(err) => return close(raw_conn, Err(err)).boxed(),
368        };
369        f(req, conn, state)
370            .then(move |res| close(raw_conn, res))
371            .boxed()
372    })
373}
374
375struct MapErr<State, Error, F> {
376    handler: Handler<State, Error>,
377    map: Arc<F>,
378}
379
380impl<State, Error, F> MapErr<State, Error, F> {
381    fn handle<'a, Error2>(
382        &self,
383        req: RequestParams,
384        conn: WebSocketConnection,
385        state: &'a State,
386    ) -> BoxFuture<'a, Result<(), SocketError<Error2>>>
387    where
388        F: 'static + Send + Sync + Fn(Error) -> Error2,
389        State: 'static + Send + Sync,
390        Error: 'static,
391    {
392        let map = self.map.clone();
393        let fut = (self.handler)(req, conn, state);
394        async move { fut.await.map_err(|err| err.map_app_specific(&*map)) }.boxed()
395    }
396}
397
398pub(crate) fn map_err<State, Error, Error2>(
399    h: Handler<State, Error>,
400    f: impl 'static + Send + Sync + Fn(Error) -> Error2,
401) -> Handler<State, Error2>
402where
403    State: 'static + Send + Sync,
404    Error: 'static,
405{
406    let handler = MapErr {
407        handler: h,
408        map: Arc::new(f),
409    };
410    Box::new(move |req, conn, state| handler.handle(req, conn, state))
411}