tide_disco/
socket.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the tide-disco library.
3
4// You should have received a copy of the MIT License
5// along with the tide-disco library. If not, see <https://mit-license.org/>.
6
7//! An interface for asynchronous communication with clients, using WebSockets.
8
9use crate::{
10    http::{content::Accept, mime},
11    method::Method,
12    request::{best_response_type, RequestError, RequestParams},
13    StatusCode,
14};
15use async_std::sync::Arc;
16use futures::{
17    future::BoxFuture,
18    select, sink,
19    stream::BoxStream,
20    task::{Context, Poll},
21    FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt,
22};
23use pin_project::pin_project;
24use serde::{de::DeserializeOwned, Serialize};
25use std::borrow::Cow;
26use std::fmt::{self, Display, Formatter};
27use std::marker::PhantomData;
28use std::pin::Pin;
29use tide_websockets::{
30    tungstenite::protocol::frame::{coding::CloseCode, CloseFrame},
31    Message, WebSocketConnection,
32};
33use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
34
35/// An error returned by a socket handler.
36///
37/// [SocketError] encapsulates application specific errors `E` returned by the user-installed
38/// handler itself. It also includes errors in the socket protocol, such as failures to turn
39/// messages sent by the user-installed handler into WebSockets messages.
40#[derive(Debug)]
41pub enum SocketError<E> {
42    AppSpecific(E),
43    Request(RequestError),
44    Binary(anyhow::Error),
45    Json(serde_json::Error),
46    WebSockets(tide_websockets::Error),
47    UnsupportedMessageType,
48    Closed,
49    IncorrectMethod { expected: Method, actual: Method },
50}
51
52impl<E> SocketError<E> {
53    pub fn status(&self) -> StatusCode {
54        match self {
55            Self::Request(_) | Self::UnsupportedMessageType | Self::IncorrectMethod { .. } => {
56                StatusCode::BAD_REQUEST
57            }
58            _ => StatusCode::INTERNAL_SERVER_ERROR,
59        }
60    }
61
62    pub fn code(&self) -> CloseCode {
63        CloseCode::Error
64    }
65
66    pub fn map_app_specific<E2>(self, f: &impl Fn(E) -> E2) -> SocketError<E2> {
67        match self {
68            Self::AppSpecific(e) => SocketError::AppSpecific(f(e)),
69            Self::Request(e) => SocketError::Request(e),
70            Self::Binary(e) => SocketError::Binary(e),
71            Self::Json(e) => SocketError::Json(e),
72            Self::WebSockets(e) => SocketError::WebSockets(e),
73            Self::UnsupportedMessageType => SocketError::UnsupportedMessageType,
74            Self::Closed => SocketError::Closed,
75            Self::IncorrectMethod { expected, actual } => {
76                SocketError::IncorrectMethod { expected, actual }
77            }
78        }
79    }
80}
81
82impl<E: Display> Display for SocketError<E> {
83    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
84        match self {
85            Self::AppSpecific(e) => write!(f, "{}", e),
86            Self::Request(e) => write!(f, "{}", e),
87            Self::Binary(e) => write!(f, "error creating byte stream: {}", e),
88            Self::Json(e) => write!(f, "error creating JSON message: {}", e),
89            Self::WebSockets(e) => write!(f, "WebSockets protocol error: {}", e),
90            Self::UnsupportedMessageType => {
91                write!(f, "unsupported content type for WebSockets message")
92            }
93            Self::Closed => write!(f, "connection closed"),
94            Self::IncorrectMethod { expected, actual } => write!(
95                f,
96                "endpoint must be called as {}, but was called as {}",
97                expected, actual
98            ),
99        }
100    }
101}
102
103impl<E> From<RequestError> for SocketError<E> {
104    fn from(err: RequestError) -> Self {
105        Self::Request(err)
106    }
107}
108
109impl<E> From<anyhow::Error> for SocketError<E> {
110    fn from(err: anyhow::Error) -> Self {
111        Self::Binary(err)
112    }
113}
114
115impl<E> From<serde_json::Error> for SocketError<E> {
116    fn from(err: serde_json::Error) -> Self {
117        Self::Json(err)
118    }
119}
120
121impl<E> From<tide_websockets::Error> for SocketError<E> {
122    fn from(err: tide_websockets::Error) -> Self {
123        Self::WebSockets(err)
124    }
125}
126
/// The wire encoding negotiated for messages sent to the client.
#[derive(Debug, Clone, Copy)]
enum MessageType {
    /// JSON, sent as WebSockets text messages.
    Json,
    /// The versioned binary format, sent as WebSockets binary messages.
    Binary,
}
132
133/// A connection facilitating bi-directional, asynchronous communication with a client.
134///
135/// [Connection] implements [Stream], which can be used to receive `FromClient` messages from the
136/// client, and [Sink] which can be used to send `ToClient` messages to the client.
137#[pin_project]
138pub struct Connection<ToClient: ?Sized, FromClient, Error, VER: StaticVersionType> {
139    #[pin]
140    conn: WebSocketConnection,
141    // [Sink] wrapper around `conn`
142    sink: Pin<Box<dyn Send + Sink<Message, Error = SocketError<Error>>>>,
143    accept: MessageType,
144    #[allow(clippy::type_complexity)]
145    _phantom: PhantomData<fn(&ToClient, &FromClient, &Error, &VER) -> ()>,
146}
147
148impl<ToClient: ?Sized, FromClient: DeserializeOwned, E, VER: StaticVersionType> Stream
149    for Connection<ToClient, FromClient, E, VER>
150{
151    type Item = Result<FromClient, SocketError<E>>;
152
153    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154        // Get a `Pin<&mut WebSocketConnection>` for the underlying connection, so we can use the
155        // `Stream` implementation of that field.
156        match self.project().conn.poll_next(cx) {
157            Poll::Ready(None) => Poll::Ready(None),
158            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
159            Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(match msg {
160                Message::Binary(bytes) => {
161                    Serializer::<VER>::deserialize(&bytes).map_err(SocketError::from)
162                }
163                Message::Text(s) => serde_json::from_str(&s).map_err(SocketError::from),
164                _ => Err(SocketError::UnsupportedMessageType),
165            })),
166            Poll::Pending => Poll::Pending,
167        }
168    }
169}
170
171impl<ToClient: Serialize + ?Sized, FromClient, E, VER: StaticVersionType> Sink<&ToClient>
172    for Connection<ToClient, FromClient, E, VER>
173{
174    type Error = SocketError<E>;
175
176    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        self.sink.as_mut().poll_ready(cx).map_err(SocketError::from)
178    }
179
180    fn start_send(mut self: Pin<&mut Self>, item: &ToClient) -> Result<(), Self::Error> {
181        let msg = match self.accept {
182            MessageType::Binary => Message::Binary(Serializer::<VER>::serialize(item)?),
183            MessageType::Json => Message::Text(serde_json::to_string(item)?),
184        };
185        self.sink.as_mut().start_send(msg)
186    }
187
188    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189        self.sink.as_mut().poll_flush(cx).map_err(SocketError::from)
190    }
191
192    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193        self.sink.as_mut().poll_close(cx).map_err(SocketError::from)
194    }
195}
196
197impl<ToClient: Serialize, FromClient, E, VER: StaticVersionType> Sink<ToClient>
198    for Connection<ToClient, FromClient, E, VER>
199{
200    type Error = SocketError<E>;
201
202    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203        Sink::<&ToClient>::poll_ready(self, cx)
204    }
205
206    fn start_send(self: Pin<&mut Self>, item: ToClient) -> Result<(), Self::Error> {
207        self.start_send(&item)
208    }
209
210    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211        Sink::<&ToClient>::poll_flush(self, cx)
212    }
213
214    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215        Sink::<&ToClient>::poll_close(self, cx)
216    }
217}
218
219impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType>
220    Connection<ToClient, FromClient, E, VER>
221{
222    #[allow(clippy::result_large_err)]
223    fn new(accept: &Accept, conn: WebSocketConnection) -> Result<Self, SocketError<E>> {
224        let ty = best_response_type(accept, &[mime::JSON, mime::BYTE_STREAM])?;
225        let ty = if ty == mime::JSON {
226            MessageType::Json
227        } else if ty == mime::BYTE_STREAM {
228            MessageType::Binary
229        } else {
230            unreachable!()
231        };
232        Ok(Self {
233            sink: Self::sink(conn.clone()),
234            conn,
235            accept: ty,
236            _phantom: Default::default(),
237        })
238    }
239
240    /// Wrap a `WebSocketConnection` in a type that implements `Sink<Message>`.
241    fn sink(
242        conn: WebSocketConnection,
243    ) -> Pin<Box<dyn Send + Sink<Message, Error = SocketError<E>>>> {
244        Box::pin(sink::unfold(conn, |conn, msg| async move {
245            conn.send(msg).await?;
246            Ok(conn)
247        }))
248    }
249}
250
251impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType> Clone
252    for Connection<ToClient, FromClient, E, VER>
253{
254    fn clone(&self) -> Self {
255        Self {
256            sink: Self::sink(self.conn.clone()),
257            conn: self.conn.clone(),
258            accept: self.accept,
259            _phantom: Default::default(),
260        }
261    }
262}
263
264pub(crate) type Handler<State, Error> = Box<
265    dyn 'static
266        + Send
267        + Sync
268        + Fn(RequestParams, WebSocketConnection, &State) -> BoxFuture<Result<(), SocketError<Error>>>,
269>;
270
271pub(crate) fn handler<State, Error, ToClient, FromClient, F, VER: StaticVersionType>(
272    f: F,
273) -> Handler<State, Error>
274where
275    F: 'static
276        + Send
277        + Sync
278        + Fn(
279            RequestParams,
280            Connection<ToClient, FromClient, Error, VER>,
281            &State,
282        ) -> BoxFuture<Result<(), Error>>,
283    State: 'static + Send + Sync,
284    ToClient: 'static + Serialize + ?Sized,
285    FromClient: 'static + DeserializeOwned,
286    Error: 'static + Send + Display,
287{
288    raw_handler(move |req, conn, state| {
289        f(req, conn, state)
290            .map_err(SocketError::AppSpecific)
291            .boxed()
292    })
293}
294
295struct StreamHandler<F, VER: StaticVersionType>(F, PhantomData<VER>);
296
impl<F, VER: StaticVersionType> StreamHandler<F, VER> {
    /// Drive a socket connection by forwarding every item of the wrapped function's stream to
    /// the client.
    ///
    /// The wrapped function is called once with the request parameters and state to create the
    /// stream. Each `Ok` item is sent to the client. The first `Err` item ends the handler with
    /// [SocketError::AppSpecific].
    ///
    /// The handler returns `Ok(())` when the stream is exhausted or when the client's side of
    /// the connection ends. It returns an error if sending fails or if receiving from the client
    /// yields an error.
    fn handle<'a, State, Error, Msg>(
        &self,
        req: RequestParams,
        conn: Connection<Msg, (), Error, VER>,
        state: &'a State,
    ) -> BoxFuture<'a, Result<(), SocketError<Error>>>
    where
        F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
        State: 'static + Send + Sync,
        Msg: 'static + Serialize + Send + Sync,
        Error: 'static + Send,
        VER: 'static + Send + Sync,
    {
        // Call the wrapped function outside the `async` block: the returned future must live for
        // `'a`, so it cannot borrow `self`.
        let mut stream = (self.0)(req, state).fuse();
        async move {
            // Appease the borrow checker, this is a cheap clone
            let (mut send, mut recv) = (conn.clone(), conn);

            // Neither stream is documented to be cancel-safe, so we store the futures outside select
            let mut item_fut = stream.next();
            let mut client_fut = recv.next().fuse();

            loop {
                select! {
                    item = item_fut => {
                        match item {
                            Some(msg) => {
                                // An application error ends the handler; otherwise forward the
                                // item and start waiting for the next one.
                                send.send(&msg.map_err(SocketError::AppSpecific)?).await?;
                                item_fut = stream.next();
                            }
                            None => {
                                // The stream is exhausted.
                                break;
                            }
                        }
                    }
                    // We don't actually expect to receive anything from the client,
                    // it is being polled only to handle connection closure by the client
                    client_msg = client_fut => {
                        client_fut = recv.next().fuse();
                        match client_msg {
                            None => return Ok(()),
                            Some(Err(e)) => return Err(e),
                            // Successfully decoded client messages are ignored.
                            _ => {}
                        }
                    }
                };
            }
            Ok(())
        }
        .boxed()
    }
}
350
351pub(crate) fn stream_handler<State, Error, Msg, F, VER>(f: F) -> Handler<State, Error>
352where
353    F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
354    State: 'static + Send + Sync,
355    Msg: 'static + Serialize + Send + Sync,
356    Error: 'static + Send + Display,
357    VER: 'static + Send + Sync + StaticVersionType,
358{
359    let handler: StreamHandler<F, VER> = StreamHandler(f, Default::default());
360    raw_handler(move |req, conn, state| handler.handle(req, conn, state))
361}
362
363fn raw_handler<State, Error, ToClient, FromClient, F, VER>(f: F) -> Handler<State, Error>
364where
365    F: 'static
366        + Send
367        + Sync
368        + Fn(
369            RequestParams,
370            Connection<ToClient, FromClient, Error, VER>,
371            &State,
372        ) -> BoxFuture<Result<(), SocketError<Error>>>,
373    State: 'static + Send + Sync,
374    ToClient: 'static + Serialize + ?Sized,
375    FromClient: 'static + DeserializeOwned,
376    Error: 'static + Send + Display,
377    VER: StaticVersionType,
378{
379    let close = |conn: WebSocketConnection, res: Result<(), SocketError<Error>>| async move {
380        // When the handler finishes, send a close message. If there was an error, include the error
381        // message.
382        let msg = res.as_ref().err().map(|err| CloseFrame {
383            code: err.code(),
384            reason: Cow::Owned(err.to_string()),
385        });
386        conn.send(Message::Close(msg)).await?;
387        res
388    };
389    Box::new(move |req, raw_conn, state| {
390        let accept = match req.accept() {
391            Ok(accept) => accept,
392            Err(err) => return close(raw_conn, Err(err.into())).boxed(),
393        };
394        let conn = match Connection::new(&accept, raw_conn.clone()) {
395            Ok(conn) => conn,
396            Err(err) => return close(raw_conn, Err(err)).boxed(),
397        };
398        f(req, conn, state)
399            .then(move |res| close(raw_conn, res))
400            .boxed()
401    })
402}
403
404struct MapErr<State, Error, F> {
405    handler: Handler<State, Error>,
406    map: Arc<F>,
407}
408
409impl<State, Error, F> MapErr<State, Error, F> {
410    fn handle<'a, Error2>(
411        &self,
412        req: RequestParams,
413        conn: WebSocketConnection,
414        state: &'a State,
415    ) -> BoxFuture<'a, Result<(), SocketError<Error2>>>
416    where
417        F: 'static + Send + Sync + Fn(Error) -> Error2,
418        State: 'static + Send + Sync,
419        Error: 'static,
420    {
421        let map = self.map.clone();
422        let fut = (self.handler)(req, conn, state);
423        async move { fut.await.map_err(|err| err.map_app_specific(&*map)) }.boxed()
424    }
425}
426
427pub(crate) fn map_err<State, Error, Error2>(
428    h: Handler<State, Error>,
429    f: impl 'static + Send + Sync + Fn(Error) -> Error2,
430) -> Handler<State, Error2>
431where
432    State: 'static + Send + Sync,
433    Error: 'static,
434{
435    let handler = MapErr {
436        handler: h,
437        map: Arc::new(f),
438    };
439    Box::new(move |req, conn, state| handler.handle(req, conn, state))
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::ServerError, testing::test_ws_client, Api, App, Url};
    use async_std::task::{sleep, spawn};
    use async_tungstenite::tungstenite::Message as TungsteniteMessage;
    use futures::{stream, StreamExt};
    use pin_project::pinned_drop;
    use portpicker::pick_unused_port;
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };
    use vbs::version::StaticVersion;

    type StaticVer01 = StaticVersion<0, 1>;

    /// A stream wrapper that records, through a shared flag, when it is dropped.
    #[pin_project(PinnedDrop)]
    struct DropStream<S: Stream> {
        #[pin]
        stream: S,
        // Set to `true` when this wrapper is dropped.
        dropped: Arc<AtomicBool>,
    }

    // Forwards every poll to the wrapped stream unchanged.
    impl<S: Stream> Stream for DropStream<S> {
        type Item = S::Item;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let stream = self.project().stream;
            stream.poll_next(cx)
        }
    }

    #[pinned_drop]
    impl<S: Stream> PinnedDrop for DropStream<S> {
        fn drop(self: Pin<&mut Self>) {
            self.dropped.store(true, Ordering::SeqCst);
        }
    }

    /// Checks that when a client closes the connection to a stream endpoint, the server stops
    /// the handler and drops the stream backing the endpoint.
    #[async_std::test]
    async fn test_stream_handler_client_closure() {
        // Setup: Create a simple API with a stream endpoint
        let port = pick_unused_port().expect("No ports available");

        let mut app = App::<(), ServerError>::with_state(());
        let toml_content = r#"
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.stream_test]
            PATH = ["/stream"]
            METHOD = "SOCKET"
            "#;

        let mut api =
            Api::<(), ServerError, StaticVer01>::new(toml_content.parse::<toml::Value>().unwrap())
                .unwrap();

        // Register a stream handler that sends multiple messages and indicates
        // whether it was dropped
        // The stream (`0..`) never ends on its own, so it can only be dropped because the
        // connection was closed.
        let dropped = Arc::new(AtomicBool::new(false));
        let _dropped = dropped.clone();
        api.stream("stream_test", move |_req, _state| {
            Box::pin(DropStream {
                stream: stream::iter(0..).map(Result::Ok),
                dropped: _dropped.clone(),
            })
        })
        .unwrap();

        app.register_module("test", api).unwrap();

        // Start the server
        spawn(async move {
            app.serve(format!("127.0.0.1:{}", port), StaticVer01::instance())
                .await
                .unwrap();
        });

        // Give the server time to start
        sleep(Duration::from_millis(500)).await;

        // Connect as a client
        let url = Url::parse(&format!("http://127.0.0.1:{}/test/stream", port)).unwrap();
        let mut ws_stream = test_ws_client(url).await;

        // Receive a few messages
        // NOTE(review): this expects JSON text messages, presumably because `test_ws_client`
        // negotiates JSON; confirm. Non-text or error messages are skipped without advancing
        // `received_count`.
        let mut received_count = 0;
        for _ in 0..5 {
            if let Some(Ok(TungsteniteMessage::Text(msg))) = ws_stream.next().await {
                let parsed: usize = serde_json::from_str(&msg).unwrap();
                assert_eq!(parsed, received_count);
                received_count += 1;
            }
        }

        // Close the client connection
        ws_stream
            .close(None)
            .await
            .expect("Failed to close connection");

        // Wait a bit to ensure the server processes the closure
        sleep(Duration::from_millis(300)).await;

        // The underlying stream should've been dropped
        assert!(dropped.load(Ordering::SeqCst));
    }
}